From 1d412726b3d6f0ea4a6a8a2280f90839da971cab Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 19 Dec 2022 11:43:29 -0500 Subject: [PATCH 1/3] Add new prefix and postfic for captions --- library/basic_caption_gui.py | 81 +++++++++++++++++++++--------------- library/blip_caption_gui.py | 25 +++++++++-- library/common_gui.py | 18 +++++++- 3 files changed, 87 insertions(+), 37 deletions(-) diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index 377a4b3e5..d4d187b93 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -1,37 +1,40 @@ import gradio as gr from easygui import msgbox import subprocess -from .common_gui import get_folder_path +from .common_gui import get_folder_path, add_pre_postfix def caption_images( - caption_text_input, images_dir_input, overwrite_input, caption_file_ext + caption_text_input, images_dir_input, overwrite_input, caption_file_ext, prefix, postfix ): - # Check for caption_text_input - if caption_text_input == '': - msgbox('Caption text is missing...') - return - # Check for images_dir_input if images_dir_input == '': msgbox('Image folder is missing...') return - print( - f'Captioning files in {images_dir_input} with {caption_text_input}...' - ) - run_cmd = f'python "tools/caption.py"' - run_cmd += f' --caption_text="{caption_text_input}"' - if overwrite_input: - run_cmd += f' --overwrite' - if caption_file_ext != '': - run_cmd += f' --caption_file_ext="{caption_file_ext}"' - run_cmd += f' "{images_dir_input}"' + if not caption_text_input == '': + print( + f'Captioning files in {images_dir_input} with {caption_text_input}...' + ) + run_cmd = f'python "tools/caption.py"' + run_cmd += f' --caption_text="{caption_text_input}"' + if overwrite_input: + run_cmd += f' --overwrite' + if caption_file_ext != '': + run_cmd += f' --caption_file_ext="{caption_file_ext}"' + run_cmd += f' "{images_dir_input}"' - print(run_cmd) + print(run_cmd) - # Run the command - subprocess.run(run_cmd) + # Run the command + subprocess.run(run_cmd) + + if overwrite_input: + # Add prefix and postfix + add_pre_postfix(folder=images_dir_input, caption_file_ext=caption_file_ext, prefix=prefix, postfix=postfix) + else: + if not prefix == '' or not postfix == '': + msgbox('Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...') print('...captioning done') @@ -47,11 +50,34 @@ def gradio_basic_caption_gui_tab(): 'This utility will allow the creation of simple caption files for each images in a folder.' ) with gr.Row(): + images_dir_input = gr.Textbox( + label='Image folder to caption', + placeholder='Directory containing the images to caption', + interactive=True, + ) + button_images_dir_input = gr.Button( + '📂', elem_id='open_folder_small' + ) + button_images_dir_input.click( + get_folder_path, outputs=images_dir_input + ) + with gr.Row(): + prefix = gr.Textbox( + label='Prefix to add to txt caption', + placeholder='(Optional)', + interactive=True, + ) caption_text_input = gr.Textbox( label='Caption text', - placeholder='Eg: , by some artist', + placeholder='Eg: , by some artist. Leave empti if you just want to add pre or postfix', interactive=True, ) + postfix = gr.Textbox( + label='Postfix to add to txt caption', + placeholder='(Optional)', + interactive=True, + ) + with gr.Row(): overwrite_input = gr.Checkbox( label='Overwrite existing captions in folder', interactive=True, @@ -62,18 +88,6 @@ def gradio_basic_caption_gui_tab(): placeholder='(Optional) Default: .caption', interactive=True, ) - with gr.Row(): - images_dir_input = gr.Textbox( - label='Image folder to caption', - placeholder='Directory containing the images to caption', - interactive=True, - ) - button_images_dir_input = gr.Button( - '📂', elem_id='open_folder_small' - ) - button_images_dir_input.click( - get_folder_path, outputs=images_dir_input - ) caption_button = gr.Button('Caption images') caption_button.click( @@ -83,5 +97,6 @@ def gradio_basic_caption_gui_tab(): images_dir_input, overwrite_input, caption_file_ext, + prefix, postfix ], ) diff --git a/library/blip_caption_gui.py b/library/blip_caption_gui.py index fe33ca7ea..5678978ad 100644 --- a/library/blip_caption_gui.py +++ b/library/blip_caption_gui.py @@ -1,8 +1,8 @@ import gradio as gr from easygui import msgbox import subprocess -from .common_gui import get_folder_path - +import os +from .common_gui import get_folder_path, add_pre_postfix def caption_images( train_data_dir, @@ -13,6 +13,8 @@ def caption_images( max_length, min_length, beam_search, + prefix, + postfix ): # Check for caption_text_input # if caption_text_input == "": @@ -42,6 +44,9 @@ def caption_images( # Run the command subprocess.run(run_cmd) + + # Add prefix and postfix + add_pre_postfix(folder=train_data_dir, caption_file_ext=caption_file_ext, prefix=prefix, postfix=postfix) print('...captioning done') @@ -68,12 +73,24 @@ def gradio_blip_caption_gui_tab(): button_train_data_dir_input.click( get_folder_path, outputs=train_data_dir ) - + with gr.Row(): caption_file_ext = gr.Textbox( label='Caption file extension', placeholder='(Optional) Default: .caption', interactive=True, ) + + prefix = gr.Textbox( + label='Prefix to add to BLIP caption', + placeholder='(Optional)', + interactive=True, + ) + + postfix = gr.Textbox( + label='Postfix to add to BLIP caption', + placeholder='(Optional)', + interactive=True, + ) batch_size = gr.Number( value=1, label='Batch size', interactive=True @@ -107,5 +124,7 @@ def gradio_blip_caption_gui_tab(): max_length, min_length, beam_search, + prefix, + postfix ], ) diff --git a/library/common_gui.py b/library/common_gui.py index 3979f7767..8d68c5817 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -1,4 +1,5 @@ from tkinter import filedialog, Tk +import os def get_file_path(file_path='', defaultextension='.json'): current_file_path = file_path @@ -56,4 +57,19 @@ def get_saveasfile_path(file_path='', defaultextension='.json'): print(file_path) - return file_path \ No newline at end of file + return file_path + +def add_pre_postfix(folder='', prefix='', postfix='', caption_file_ext='.caption'): + files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] + if not prefix == '': + prefix = f'{prefix} ' + if not postfix == '': + postfix = f' {postfix}' + + for file in files: + with open(os.path.join(folder, file), 'r+') as f: + content = f.read() + content = content.rstrip() + f.seek(0,0) + f.write(f'{prefix}{content}{postfix}') + f.close() \ No newline at end of file From 1f1dd5c4deb711f8ef08f82ae81a4ccb02e083b9 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 19 Dec 2022 21:50:05 -0500 Subject: [PATCH 2/3] Add support for: shuffle_caption, save_state, resume, prior_loss_weight, Fix issue with config open and save --- dreambooth_gui.py | 167 ++- library/basic_caption_gui.py | 23 +- library/blip_caption_gui.py | 18 +- library/common_gui.py | 44 +- library/convert_model_gui.py | 148 +- library/dataset_balancing_gui.py | 35 +- library/model_util.py | 2316 +++++++++++++++++------------- 7 files changed, 1636 insertions(+), 1115 deletions(-) diff --git a/dreambooth_gui.py b/dreambooth_gui.py index e8f370817..efd1af684 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -10,7 +10,9 @@ import subprocess import pathlib import shutil -from library.dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) from library.basic_caption_gui import gradio_basic_caption_gui_tab from library.convert_model_gui import gradio_convert_model_tab from library.blip_caption_gui import gradio_blip_caption_gui_tab @@ -20,14 +22,14 @@ get_folder_path, remove_doublequote, get_file_path, - get_saveasfile_path + get_saveasfile_path, ) from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +document_symbol = '\U0001F4C4' # 📄 def save_configuration( @@ -60,7 +62,11 @@ def save_configuration( stop_text_encoder_training, use_8bit_adam, xformers, - save_model_as + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, ): original_file_path = file_path @@ -68,22 +74,14 @@ def save_configuration( if save_as_bool: print('Save as...') - # file_path = filesavebox( - # 'Select the config file to save', - # default='finetune.json', - # filetypes='*.json', - # ) file_path = get_saveasfile_path(file_path) else: print('Save...') if file_path == None or file_path == '': - # file_path = filesavebox( - # 'Select the config file to save', - # default='finetune.json', - # filetypes='*.json', - # ) file_path = get_saveasfile_path(file_path) + # print(file_path) + if file_path == None or file_path == '': return original_file_path # In case a file_path was provided and the user decide to cancel the open action @@ -116,7 +114,11 @@ def save_configuration( 'stop_text_encoder_training': stop_text_encoder_training, 'use_8bit_adam': use_8bit_adam, 'xformers': xformers, - 'save_model_as': save_model_as + 'save_model_as': save_model_as, + 'shuffle_caption': shuffle_caption, + 'save_state': save_state, + 'resume': resume, + 'prior_loss_weight': prior_loss_weight, } # Save the data to the selected file @@ -155,14 +157,18 @@ def open_configuration( stop_text_encoder_training, use_8bit_adam, xformers, - save_model_as + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, ): original_file_path = file_path file_path = get_file_path(file_path) + # print(file_path) - if file_path != '' and file_path != None: - print(file_path) + if not file_path == '' and not file_path == None: # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) @@ -204,7 +210,11 @@ def open_configuration( my_data.get('stop_text_encoder_training', stop_text_encoder_training), my_data.get('use_8bit_adam', use_8bit_adam), my_data.get('xformers', xformers), - my_data.get('save_model_as', save_model_as) + my_data.get('save_model_as', save_model_as), + my_data.get('shuffle_caption', shuffle_caption), + my_data.get('save_state', save_state), + my_data.get('resume', resume), + my_data.get('prior_loss_weight', prior_loss_weight), ) @@ -236,7 +246,11 @@ def train_model( stop_text_encoder_training_pct, use_8bit_adam, xformers, - save_model_as + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, ): def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required @@ -360,6 +374,10 @@ def save_inference_file(output_dir, v2, v_parameterization): run_cmd += ' --use_8bit_adam' if xformers: run_cmd += ' --xformers' + if shuffle_caption: + run_cmd += ' --shuffle_caption' + if save_state: + run_cmd += ' --save_state' run_cmd += ( f' --pretrained_model_name_or_path={pretrained_model_name_or_path}' ) @@ -382,9 +400,15 @@ def save_inference_file(output_dir, v2, v_parameterization): run_cmd += f' --logging_dir={logging_dir}' run_cmd += f' --caption_extention={caption_extention}' if not stop_text_encoder_training == 0: - run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}' + run_cmd += ( + f' --stop_text_encoder_training={stop_text_encoder_training}' + ) if not save_model_as == 'same as source model': run_cmd += f' --save_model_as={save_model_as}' + if not resume == '': + run_cmd += f' --resume={resume}' + if not float(prior_loss_weight) == 1.0: + run_cmd += f' --prior_loss_weight={prior_loss_weight}' print(run_cmd) # Run the command @@ -392,7 +416,7 @@ def save_inference_file(output_dir, v2, v_parameterization): # check if output_dir/last is a folder... therefore it is a diffuser model last_dir = pathlib.Path(f'{output_dir}/last') - + if not last_dir.is_dir(): # Copy inference model for v2 if required save_inference_file(output_dir, v2, v_parameterization) @@ -472,8 +496,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): ) config_file_name = gr.Textbox( label='', - # placeholder="type the configuration file path or use the 'Open' button above to select it...", - interactive=False + placeholder="type the configuration file path or use the 'Open' button above to select it...", + interactive=True, ) # config_file_name.change( # remove_doublequote, @@ -491,13 +515,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): document_symbol, elem_id='open_folder_small' ) pretrained_model_name_or_path_fille.click( - get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input + get_file_path, + inputs=[pretrained_model_name_or_path_input], + outputs=pretrained_model_name_or_path_input, ) pretrained_model_name_or_path_folder = gr.Button( folder_symbol, elem_id='open_folder_small' ) pretrained_model_name_or_path_folder.click( - get_folder_path, outputs=pretrained_model_name_or_path_input + get_folder_path, + outputs=pretrained_model_name_or_path_input, ) model_list = gr.Dropdown( label='(Optional) Model Quick Pick', @@ -517,10 +544,10 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): 'same as source model', 'ckpt', 'diffusers', - "diffusers_safetensors", + 'diffusers_safetensors', 'safetensors', ], - value='same as source model' + value='same as source model', ) with gr.Row(): v2_input = gr.Checkbox(label='v2', value=True) @@ -607,7 +634,9 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): ) with gr.Tab('Training parameters'): with gr.Row(): - learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6) + learning_rate_input = gr.Textbox( + label='Learning rate', value=1e-6 + ) lr_scheduler_input = gr.Dropdown( label='LR Scheduler', choices=[ @@ -662,7 +691,9 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): with gr.Row(): seed_input = gr.Textbox(label='Seed', value=1234) max_resolution_input = gr.Textbox( - label='Max resolution', value='512,512', placeholder='512,512' + label='Max resolution', + value='512,512', + placeholder='512,512', ) with gr.Row(): caption_extention_input = gr.Textbox( @@ -676,27 +707,45 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): step=1, label='Stop text encoder training', ) - with gr.Row(): - full_fp16_input = gr.Checkbox( - label='Full fp16 training (experimental)', value=False - ) - no_token_padding_input = gr.Checkbox( - label='No token padding', value=False - ) - - gradient_checkpointing_input = gr.Checkbox( - label='Gradient checkpointing', value=False - ) with gr.Row(): enable_bucket_input = gr.Checkbox( label='Enable buckets', value=True ) - cache_latent_input = gr.Checkbox(label='Cache latent', value=True) + cache_latent_input = gr.Checkbox( + label='Cache latent', value=True + ) use_8bit_adam_input = gr.Checkbox( label='Use 8bit adam', value=True ) xformers_input = gr.Checkbox(label='Use xformers', value=True) - + with gr.Accordion('Advanced Configuration', open=False): + with gr.Row(): + full_fp16_input = gr.Checkbox( + label='Full fp16 training (experimental)', value=False + ) + no_token_padding_input = gr.Checkbox( + label='No token padding', value=False + ) + + gradient_checkpointing_input = gr.Checkbox( + label='Gradient checkpointing', value=False + ) + + shuffle_caption = gr.Checkbox( + label='Shuffle caption', value=False + ) + save_state = gr.Checkbox(label='Save state', value=False) + with gr.Row(): + resume = gr.Textbox( + label='Resume', + placeholder='path to "last-state" state folder to resume from', + ) + resume_button = gr.Button('📂', elem_id='open_folder_small') + resume_button.click(get_folder_path, outputs=resume) + prior_loss_weight = gr.Number( + label='Prior loss weight', value=1.0 + ) + button_run = gr.Button('Train model') with gr.Tab('Utilities'): @@ -713,8 +762,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): gradio_dataset_balancing_tab() gradio_convert_model_tab() - - button_open_config.click( open_configuration, inputs=[ @@ -746,7 +793,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): stop_text_encoder_training_input, use_8bit_adam_input, xformers_input, - save_model_as_dropdown + save_model_as_dropdown, + shuffle_caption, + save_state, + resume, + prior_loss_weight, ], outputs=[ config_file_name, @@ -777,7 +828,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): stop_text_encoder_training_input, use_8bit_adam_input, xformers_input, - save_model_as_dropdown + save_model_as_dropdown, + shuffle_caption, + save_state, + resume, + prior_loss_weight, ], ) @@ -815,7 +870,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): stop_text_encoder_training_input, use_8bit_adam_input, xformers_input, - save_model_as_dropdown + save_model_as_dropdown, + shuffle_caption, + save_state, + resume, + prior_loss_weight, ], outputs=[config_file_name], ) @@ -852,7 +911,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): stop_text_encoder_training_input, use_8bit_adam_input, xformers_input, - save_model_as_dropdown + save_model_as_dropdown, + shuffle_caption, + save_state, + resume, + prior_loss_weight, ], outputs=[config_file_name], ) @@ -887,7 +950,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): stop_text_encoder_training_input, use_8bit_adam_input, xformers_input, - save_model_as_dropdown + save_model_as_dropdown, + shuffle_caption, + save_state, + resume, + prior_loss_weight, ], ) diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index d4d187b93..6f209373a 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -5,7 +5,12 @@ def caption_images( - caption_text_input, images_dir_input, overwrite_input, caption_file_ext, prefix, postfix + caption_text_input, + images_dir_input, + overwrite_input, + caption_file_ext, + prefix, + postfix, ): # Check for images_dir_input if images_dir_input == '': @@ -28,13 +33,20 @@ def caption_images( # Run the command subprocess.run(run_cmd) - + if overwrite_input: # Add prefix and postfix - add_pre_postfix(folder=images_dir_input, caption_file_ext=caption_file_ext, prefix=prefix, postfix=postfix) + add_pre_postfix( + folder=images_dir_input, + caption_file_ext=caption_file_ext, + prefix=prefix, + postfix=postfix, + ) else: if not prefix == '' or not postfix == '': - msgbox('Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...') + msgbox( + 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...' + ) print('...captioning done') @@ -97,6 +109,7 @@ def gradio_basic_caption_gui_tab(): images_dir_input, overwrite_input, caption_file_ext, - prefix, postfix + prefix, + postfix, ], ) diff --git a/library/blip_caption_gui.py b/library/blip_caption_gui.py index 5678978ad..6583b9f9f 100644 --- a/library/blip_caption_gui.py +++ b/library/blip_caption_gui.py @@ -4,6 +4,7 @@ import os from .common_gui import get_folder_path, add_pre_postfix + def caption_images( train_data_dir, caption_file_ext, @@ -14,7 +15,7 @@ def caption_images( min_length, beam_search, prefix, - postfix + postfix, ): # Check for caption_text_input # if caption_text_input == "": @@ -44,9 +45,14 @@ def caption_images( # Run the command subprocess.run(run_cmd) - + # Add prefix and postfix - add_pre_postfix(folder=train_data_dir, caption_file_ext=caption_file_ext, prefix=prefix, postfix=postfix) + add_pre_postfix( + folder=train_data_dir, + caption_file_ext=caption_file_ext, + prefix=prefix, + postfix=postfix, + ) print('...captioning done') @@ -79,13 +85,13 @@ def gradio_blip_caption_gui_tab(): placeholder='(Optional) Default: .caption', interactive=True, ) - + prefix = gr.Textbox( label='Prefix to add to BLIP caption', placeholder='(Optional)', interactive=True, ) - + postfix = gr.Textbox( label='Postfix to add to BLIP caption', placeholder='(Optional)', @@ -125,6 +131,6 @@ def gradio_blip_caption_gui_tab(): min_length, beam_search, prefix, - postfix + postfix, ], ) diff --git a/library/common_gui.py b/library/common_gui.py index 8d68c5817..bf6f291a8 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -1,16 +1,20 @@ from tkinter import filedialog, Tk import os + def get_file_path(file_path='', defaultextension='.json'): current_file_path = file_path # print(f'current file path: {current_file_path}') - + root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() - file_path = filedialog.askopenfilename(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension) + file_path = filedialog.askopenfilename( + filetypes=(('Config files', '*.json'), ('All files', '*')), + defaultextension=defaultextension, + ) root.destroy() - + if file_path == '': file_path = current_file_path @@ -26,50 +30,58 @@ def remove_doublequote(file_path): def get_folder_path(folder_path=''): current_folder_path = folder_path - + root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() folder_path = filedialog.askdirectory() root.destroy() - + if folder_path == '': folder_path = current_folder_path return folder_path + def get_saveasfile_path(file_path='', defaultextension='.json'): current_file_path = file_path # print(f'current file path: {current_file_path}') - + root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() - save_file_path = filedialog.asksaveasfile(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension) + save_file_path = filedialog.asksaveasfile( + filetypes=(('Config files', '*.json'), ('All files', '*')), + defaultextension=defaultextension, + ) root.destroy() - - # file_path = file_path.name - if file_path == '': + + # print(save_file_path) + + if save_file_path == None: file_path = current_file_path else: print(save_file_path.name) file_path = save_file_path.name - print(file_path) + # print(file_path) return file_path -def add_pre_postfix(folder='', prefix='', postfix='', caption_file_ext='.caption'): - files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] + +def add_pre_postfix( + folder='', prefix='', postfix='', caption_file_ext='.caption' +): + files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] if not prefix == '': prefix = f'{prefix} ' if not postfix == '': postfix = f' {postfix}' - + for file in files: with open(os.path.join(folder, file), 'r+') as f: content = f.read() content = content.rstrip() - f.seek(0,0) + f.seek(0, 0) f.write(f'{prefix}{content}{postfix}') - f.close() \ No newline at end of file + f.close() diff --git a/library/convert_model_gui.py b/library/convert_model_gui.py index 8f6751acc..6d9c99935 100644 --- a/library/convert_model_gui.py +++ b/library/convert_model_gui.py @@ -8,37 +8,45 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +document_symbol = '\U0001F4C4' # 📄 -def convert_model(source_model_input, source_model_type, target_model_folder_input, target_model_name_input, target_model_type, target_save_precision_type): + +def convert_model( + source_model_input, + source_model_type, + target_model_folder_input, + target_model_name_input, + target_model_type, + target_save_precision_type, +): # Check for caption_text_input - if source_model_type == "": - msgbox("Invalid source model type") + if source_model_type == '': + msgbox('Invalid source model type') return - + # Check if source model exist if os.path.isfile(source_model_input): print('The provided source model is a file') elif os.path.isdir(source_model_input): print('The provided model is a folder') else: - msgbox("The provided source model is neither a file nor a folder") + msgbox('The provided source model is neither a file nor a folder') return - + # Check if source model exist if os.path.isdir(target_model_folder_input): print('The provided model folder exist') else: - msgbox("The provided target folder does not exist") + msgbox('The provided target folder does not exist') return - + run_cmd = f'.\\venv\Scripts\python.exe "tools/convert_diffusers20_original_sd.py"' - + v1_models = [ - 'runwayml/stable-diffusion-v1-5', - 'CompVis/stable-diffusion-v1-4', + 'runwayml/stable-diffusion-v1-5', + 'CompVis/stable-diffusion-v1-4', ] - + # check if v1 models if str(source_model_type) in v1_models: print('SD v1 model specified. Setting --v1 parameter') @@ -46,54 +54,76 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp else: print('SD v2 model specified. Setting --v2 parameter') run_cmd += ' --v2' - + if not target_save_precision_type == 'unspecified': run_cmd += f' --{target_save_precision_type}' - - if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors": + + if ( + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' + ): run_cmd += f' --reference_model="{source_model_type}"' - + if target_model_type == 'diffuser_safetensors': run_cmd += ' --use_safetensors' - + run_cmd += f' "{source_model_input}"' - - if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors": - target_model_path = os.path.join(target_model_folder_input, target_model_name_input) + + if ( + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' + ): + target_model_path = os.path.join( + target_model_folder_input, target_model_name_input + ) run_cmd += f' "{target_model_path}"' else: - target_model_path = os.path.join(target_model_folder_input, f'{target_model_name_input}.{target_model_type}') + target_model_path = os.path.join( + target_model_folder_input, + f'{target_model_name_input}.{target_model_type}', + ) run_cmd += f' "{target_model_path}"' - + print(run_cmd) - + # Run the command subprocess.run(run_cmd) - - if not target_model_type == "diffuser" or target_model_type == "diffuser_safetensors": - - v2_models = ['stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base',] - v_parameterization =[ - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2',] - + + if ( + not target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' + ): + + v2_models = [ + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + ] + v_parameterization = [ + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + ] + if str(source_model_type) in v2_models: - inference_file = os.path.join(target_model_folder_input, f'{target_model_name_input}.yaml') + inference_file = os.path.join( + target_model_folder_input, f'{target_model_name_input}.yaml' + ) print(f'Saving v2-inference.yaml as {inference_file}') shutil.copy( f'./v2_inference/v2-inference.yaml', f'{inference_file}', ) - + if str(source_model_type) in v_parameterization: - inference_file = os.path.join(target_model_folder_input, f'{target_model_name_input}.yaml') + inference_file = os.path.join( + target_model_folder_input, f'{target_model_name_input}.yaml' + ) print(f'Saving v2-inference-v.yaml as {inference_file}') shutil.copy( f'./v2_inference/v2-inference-v.yaml', f'{inference_file}', ) + # parser = argparse.ArgumentParser() # parser.add_argument("--v1", action='store_true', # help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') @@ -138,22 +168,27 @@ def gradio_convert_model_tab(): button_source_model_dir.click( get_folder_path, outputs=source_model_input ) - + button_source_model_file = gr.Button( document_symbol, elem_id='open_folder_small' ) button_source_model_file.click( - get_file_path, inputs=[source_model_input], outputs=source_model_input + get_file_path, + inputs=[source_model_input], + outputs=source_model_input, ) - - source_model_type = gr.Dropdown(label="Source model type", choices=[ + + source_model_type = gr.Dropdown( + label='Source model type', + choices=[ 'stabilityai/stable-diffusion-2-1-base', 'stabilityai/stable-diffusion-2-base', 'stabilityai/stable-diffusion-2-1', 'stabilityai/stable-diffusion-2', 'runwayml/stable-diffusion-v1-5', 'CompVis/stable-diffusion-v1-4', - ],) + ], + ) with gr.Row(): target_model_folder_input = gr.Textbox( label='Target model folder', @@ -166,30 +201,37 @@ def gradio_convert_model_tab(): button_target_model_folder.click( get_folder_path, outputs=target_model_folder_input ) - + target_model_name_input = gr.Textbox( label='Target model name', placeholder='target model name...', interactive=True, ) - target_model_type = gr.Dropdown(label="Target model type", choices=[ + target_model_type = gr.Dropdown( + label='Target model type', + choices=[ 'diffuser', 'diffuser_safetensors', 'ckpt', 'safetensors', - ],) - target_save_precision_type = gr.Dropdown(label="Target model precison", choices=[ - 'unspecified', - 'fp16', - 'bf16', - 'float' - ], value='unspecified') - - + ], + ) + target_save_precision_type = gr.Dropdown( + label='Target model precison', + choices=['unspecified', 'fp16', 'bf16', 'float'], + value='unspecified', + ) + convert_button = gr.Button('Convert model') convert_button.click( convert_model, - inputs=[source_model_input, source_model_type, target_model_folder_input, target_model_name_input, target_model_type, target_save_precision_type + inputs=[ + source_model_input, + source_model_type, + target_model_folder_input, + target_model_name_input, + target_model_type, + target_save_precision_type, ], ) diff --git a/library/dataset_balancing_gui.py b/library/dataset_balancing_gui.py index e6fdd937b..d109cfbd1 100644 --- a/library/dataset_balancing_gui.py +++ b/library/dataset_balancing_gui.py @@ -13,7 +13,7 @@ def dataset_balancing(concept_repeats, folder, insecure): - + if not concept_repeats > 0: # Display an error message if the total number of repeats is not a valid integer msgbox('Please enter a valid integer for the total number of repeats.') @@ -72,23 +72,35 @@ def dataset_balancing(concept_repeats, folder, insecure): os.rename(old_name, new_name) else: - print(f"Skipping folder {subdir} because it does not match kohya_ss expected syntax...") + print( + f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...' + ) msgbox('Dataset balancing completed...') + def warning(insecure): if insecure: - if boolbox(f'WARNING!!! You have asked to rename non kohya_ss _ folders...\n\nAre you sure you want to do that?', choices=("Yes, I like danger", "No, get me out of here")): + if boolbox( + f'WARNING!!! You have asked to rename non kohya_ss _ folders...\n\nAre you sure you want to do that?', + choices=('Yes, I like danger', 'No, get me out of here'), + ): return True else: return False + def gradio_dataset_balancing_tab(): with gr.Tab('Dataset balancing'): - gr.Markdown('This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.') - gr.Markdown('WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!') + gr.Markdown( + 'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.' + ) + gr.Markdown( + 'WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!' + ) with gr.Row(): - select_dataset_folder_input = gr.Textbox(label="Dataset folder", + select_dataset_folder_input = gr.Textbox( + label='Dataset folder', placeholder='Folder containing the concepts folders to balance...', interactive=True, ) @@ -106,10 +118,17 @@ def gradio_dataset_balancing_tab(): label='Training steps per concept per epoch', ) with gr.Accordion('Advanced options', open=False): - insecure = gr.Checkbox(value=False, label="DANGER!!! -- Insecure folder renaming -- DANGER!!!") + insecure = gr.Checkbox( + value=False, + label='DANGER!!! -- Insecure folder renaming -- DANGER!!!', + ) insecure.change(warning, inputs=insecure, outputs=insecure) balance_button = gr.Button('Balance dataset') balance_button.click( dataset_balancing, - inputs=[total_repeats_number, select_dataset_folder_input, insecure], + inputs=[ + total_repeats_number, + select_dataset_folder_input, + insecure, + ], ) diff --git a/library/model_util.py b/library/model_util.py index f34530252..29d442034 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,7 +5,12 @@ import os import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) from safetensors.torch import load_file, save_file # DiffUsers版StableDiffusionのモデルパラメータ @@ -36,8 +41,8 @@ V2_UNET_PARAMS_CONTEXT_DIM = 1024 # Diffusersの設定を読み込むための参照モデル -DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" -DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" +DIFFUSERS_REF_MODEL_ID_V1 = 'runwayml/stable-diffusion-v1-5' +DIFFUSERS_REF_MODEL_ID_V2 = 'stabilityai/stable-diffusion-2-1' # region StableDiffusion->Diffusersの変換コード @@ -45,588 +50,845 @@ def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return '.'.join(path.split('.')[n_shave_prefix_segments:]) + else: + return '.'.join(path.split('.')[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace('in_layers.0', 'norm1') + new_item = new_item.replace('in_layers.2', 'conv1') + + new_item = new_item.replace('out_layers.0', 'norm2') + new_item = new_item.replace('out_layers.3', 'conv2') + + new_item = new_item.replace('emb_layers.1', 'time_emb_proj') + new_item = new_item.replace('skip_connection', 'conv_shortcut') + + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({'old': old_item, 'new': new_item}) - return mapping + return mapping def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace('nin_shortcut', 'conv_shortcut') + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({'old': old_item, 'new': new_item}) - return mapping + return mapping def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({'old': old_item, 'new': new_item}) - return mapping + return mapping def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") + new_item = new_item.replace('norm.weight', 'group_norm.weight') + new_item = new_item.replace('norm.bias', 'group_norm.bias') - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + new_item = new_item.replace('q.weight', 'query.weight') + new_item = new_item.replace('q.bias', 'query.bias') - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace('k.weight', 'key.weight') + new_item = new_item.replace('k.bias', 'key.bias') - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace('v.weight', 'value.weight') + new_item = new_item.replace('v.bias', 'value.bias') - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) - mapping.append({"old": old_item, "new": new_item}) + mapping.append({'old': old_item, 'new': new_item}) - return mapping + return mapping def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, ): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance( + paths, list + ), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = ( + (-1, channels) if len(old_tensor.shape) == 3 else (-1) + ) + + num_heads = old_tensor.shape[0] // config['num_head_channels'] // 3 + + old_tensor = old_tensor.reshape( + (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] + ) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map['query']] = query.reshape(target_shape) + checkpoint[path_map['key']] = key.reshape(target_shape) + checkpoint[path_map['value']] = value.reshape(target_shape) + + for path in paths: + new_path = path['new'] + + # These have already been assigned + if ( + attention_paths_to_split is not None + and new_path in attention_paths_to_split + ): + continue + + # Global renaming happens here + new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0') + new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0') + new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1') + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace( + replacement['old'], replacement['new'] + ) + + # proj_attn.weight has to be converted from conv 1D to linear + if 'proj_attn.weight' in new_path: + checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path['old']] - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ['query.weight', 'key.weight', 'value.weight'] + for key in keys: + if '.'.join(key.split('.')[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif 'proj_attn.weight' in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 +def linear_transformer_to_conv(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ['proj_in.weight', 'proj_out.weight'] + for key in keys: + if '.'.join(key.split('.')[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) +def convert_ldm_unet_checkpoint(v2, checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = 'model.diffusion_model.' + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, '')] = checkpoint.pop(key) - for path in paths: - new_path = path["new"] + new_checkpoint = {} - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue + new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict[ + 'time_embed.0.weight' + ] + new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict[ + 'time_embed.0.bias' + ] + new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict[ + 'time_embed.2.weight' + ] + new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict[ + 'time_embed.2.bias' + ] - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + new_checkpoint['conv_in.weight'] = unet_state_dict[ + 'input_blocks.0.0.weight' + ] + new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias'] + + new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight'] + new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias'] + new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight'] + new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias'] + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + { + '.'.join(layer.split('.')[:2]) + for layer in unet_state_dict + if 'input_blocks' in layer + } + ) + input_blocks = { + layer_id: [ + key + for key in unet_state_dict + if f'input_blocks.{layer_id}.' in key + ] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + { + '.'.join(layer.split('.')[:2]) + for layer in unet_state_dict + if 'middle_block' in layer + } + ) + middle_blocks = { + layer_id: [ + key + for key in unet_state_dict + if f'middle_block.{layer_id}.' in key + ] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len( + { + '.'.join(layer.split('.')[:2]) + for layer in unet_state_dict + if 'output_blocks' in layer + } + ) + output_blocks = { + layer_id: [ + key + for key in unet_state_dict + if f'output_blocks.{layer_id}.' in key + ] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config['layers_per_block'] + 1) + layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1) + + resnets = [ + key + for key in input_blocks[i] + if f'input_blocks.{i}.0' in key + and f'input_blocks.{i}.0.op' not in key + ] + attentions = [ + key for key in input_blocks[i] if f'input_blocks.{i}.1' in key + ] - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) + if f'input_blocks.{i}.0.op.weight' in unet_state_dict: + new_checkpoint[ + f'down_blocks.{block_id}.downsamplers.0.conv.weight' + ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight') + new_checkpoint[ + f'down_blocks.{block_id}.downsamplers.0.conv.bias' + ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias') - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] + paths = renew_resnet_paths(resnets) + meta_path = { + 'old': f'input_blocks.{i}.0', + 'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}', + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + 'old': f'input_blocks.{i}.1', + 'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}', + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config + ) -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config + ) + attentions_paths = renew_attention_paths(attentions) + meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'} + assign_to_checkpoint( + attentions_paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) -def linear_transformer_to_conv(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim == 2: - checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + for i in range(num_output_blocks): + block_id = i // (config['layers_per_block'] + 1) + layer_in_block_id = i % (config['layers_per_block'] + 1) + output_block_layers = [ + shave_segments(name, 2) for name in output_blocks[i] + ] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split('.')[0], shave_segments( + layer, 1 + ) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [ + key + for key in output_blocks[i] + if f'output_blocks.{i}.0' in key + ] + attentions = [ + key + for key in output_blocks[i] + if f'output_blocks.{i}.1' in key + ] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = { + 'old': f'output_blocks.{i}.0', + 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}', + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + # オリジナル: + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + + # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが + for l in output_block_list.values(): + l.sort() + + if ['conv.bias', 'conv.weight'] in output_block_list.values(): + index = list(output_block_list.values()).index( + ['conv.bias', 'conv.weight'] + ) + new_checkpoint[ + f'up_blocks.{block_id}.upsamplers.0.conv.bias' + ] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias'] + new_checkpoint[ + f'up_blocks.{block_id}.upsamplers.0.conv.weight' + ] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight'] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + 'old': f'output_blocks.{i}.1', + 'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}', + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + else: + resnet_0_paths = renew_resnet_paths( + output_block_layers, n_shave_prefix_segments=1 + ) + for path in resnet_0_paths: + old_path = '.'.join(['output_blocks', str(i), path['old']]) + new_path = '.'.join( + [ + 'up_blocks', + str(block_id), + 'resnets', + str(layer_in_block_id), + path['new'], + ] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する + if v2: + linear_transformer_to_conv(new_checkpoint) + return new_checkpoint -def convert_ldm_unet_checkpoint(v2, checkpoint, config): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - unet_key = "model.diffusion_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = 'first_stage_model.' + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, '')] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint['encoder.conv_in.weight'] = vae_state_dict[ + 'encoder.conv_in.weight' + ] + new_checkpoint['encoder.conv_in.bias'] = vae_state_dict[ + 'encoder.conv_in.bias' + ] + new_checkpoint['encoder.conv_out.weight'] = vae_state_dict[ + 'encoder.conv_out.weight' + ] + new_checkpoint['encoder.conv_out.bias'] = vae_state_dict[ + 'encoder.conv_out.bias' + ] + new_checkpoint['encoder.conv_norm_out.weight'] = vae_state_dict[ + 'encoder.norm_out.weight' + ] + new_checkpoint['encoder.conv_norm_out.bias'] = vae_state_dict[ + 'encoder.norm_out.bias' ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - # オリジナル: - # if ["conv.weight", "conv.bias"] in output_block_list.values(): - # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) - - # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが - for l in output_block_list.values(): - l.sort() - - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" + new_checkpoint['decoder.conv_in.weight'] = vae_state_dict[ + 'decoder.conv_in.weight' + ] + new_checkpoint['decoder.conv_in.bias'] = vae_state_dict[ + 'decoder.conv_in.bias' + ] + new_checkpoint['decoder.conv_out.weight'] = vae_state_dict[ + 'decoder.conv_out.weight' + ] + new_checkpoint['decoder.conv_out.bias'] = vae_state_dict[ + 'decoder.conv_out.bias' + ] + new_checkpoint['decoder.conv_norm_out.weight'] = vae_state_dict[ + 'decoder.norm_out.weight' + ] + new_checkpoint['decoder.conv_norm_out.bias'] = vae_state_dict[ + 'decoder.norm_out.bias' + ] + + new_checkpoint['quant_conv.weight'] = vae_state_dict['quant_conv.weight'] + new_checkpoint['quant_conv.bias'] = vae_state_dict['quant_conv.bias'] + new_checkpoint['post_quant_conv.weight'] = vae_state_dict[ + 'post_quant_conv.weight' + ] + new_checkpoint['post_quant_conv.bias'] = vae_state_dict[ + 'post_quant_conv.bias' + ] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len( + { + '.'.join(layer.split('.')[:3]) + for layer in vae_state_dict + if 'encoder.down' in layer + } + ) + down_blocks = { + layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] + for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len( + { + '.'.join(layer.split('.')[:3]) + for layer in vae_state_dict + if 'decoder.up' in layer + } + ) + up_blocks = { + layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] + for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [ + key + for key in down_blocks[i] + if f'down.{i}' in key and f'down.{i}.downsample' not in key ] - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] + if f'encoder.down.{i}.downsample.conv.weight' in vae_state_dict: + new_checkpoint[ + f'encoder.down_blocks.{i}.downsamplers.0.conv.weight' + ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.weight') + new_checkpoint[ + f'encoder.down_blocks.{i}.downsamplers.0.conv.bias' + ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.bias') - if len(attentions): - paths = renew_attention_paths(attentions) + paths = renew_vae_resnet_paths(resnets) meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + 'old': f'down.{i}.block', + 'new': f'down_blocks.{i}.resnets', } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する - if v2: - linear_transformer_to_conv(new_checkpoint) - return new_checkpoint + mid_resnets = [key for key in vae_state_dict if 'encoder.mid.block' in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [ + key for key in mid_resnets if f'encoder.mid.block_{i}' in key + ] + paths = renew_vae_resnet_paths(resnets) + meta_path = { + 'old': f'mid.block_{i}', + 'new': f'mid_block.resnets.{i - 1}', + } + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - # if len(vae_state_dict) == 0: - # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict - # vae_state_dict = checkpoint - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + mid_attentions = [ + key for key in vae_state_dict if 'encoder.mid.attn' in key ] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key + for key in up_blocks[block_id] + if f'up.{block_id}' in key and f'up.{block_id}.upsample' not in key + ] - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] + if f'decoder.up.{block_id}.upsample.conv.weight' in vae_state_dict: + new_checkpoint[ + f'decoder.up_blocks.{i}.upsamplers.0.conv.weight' + ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.weight'] + new_checkpoint[ + f'decoder.up_blocks.{i}.upsamplers.0.conv.bias' + ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.bias'] - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + paths = renew_vae_resnet_paths(resnets) + meta_path = { + 'old': f'up.{block_id}.block', + 'new': f'up_blocks.{i}.resnets', + } + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + mid_resnets = [key for key in vae_state_dict if 'decoder.mid.block' in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [ + key for key in mid_resnets if f'decoder.mid.block_{i}' in key + ] - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + paths = renew_vae_resnet_paths(resnets) + meta_path = { + 'old': f'mid.block_{i}', + 'new': f'mid_block.resnets.{i - 1}', + } + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint + mid_attentions = [ + key for key in vae_state_dict if 'decoder.mid.attn' in key + ] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint def create_unet_diffusers_config(v2): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # unet_params = original_config.model.params.unet_config.params - - block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - config = dict( - sample_size=UNET_PARAMS_IMAGE_SIZE, - in_channels=UNET_PARAMS_IN_CHANNELS, - out_channels=UNET_PARAMS_OUT_CHANNELS, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - ) - - return config + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # unet_params = original_config.model.params.unet_config.params + + block_out_channels = [ + UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT + ] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = ( + 'CrossAttnDownBlock2D' + if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS + else 'DownBlock2D' + ) + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = ( + 'CrossAttnUpBlock2D' + if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS + else 'UpBlock2D' + ) + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM + if not v2 + else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS + if not v2 + else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + ) + + return config def create_vae_diffusers_config(): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # vae_params = original_config.model.params.first_stage_config.params.ddconfig - # _ = original_config.model.params.first_stage_config.params.embed_dim - block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = dict( - sample_size=VAE_PARAMS_RESOLUTION, - in_channels=VAE_PARAMS_IN_CHANNELS, - out_channels=VAE_PARAMS_OUT_CH, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=VAE_PARAMS_Z_CHANNELS, - layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, - ) - return config + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ['DownEncoderBlock2D'] * len(block_out_channels) + up_block_types = ['UpDecoderBlock2D'] * len(block_out_channels) + + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config def convert_ldm_clip_checkpoint_v1(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] - return text_model_dict + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith('cond_stage_model.transformer'): + text_model_dict[ + key[len('cond_stage_model.transformer.') :] + ] = checkpoint[key] + return text_model_dict def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): - # 嫌になるくらい違うぞ! - def convert_key(key): - if not key.startswith("cond_stage_model"): - return None - - # common conversion - key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") - key = key.replace("cond_stage_model.model.", "text_model.") - - if "resblocks" in key: - # resblocks conversion - key = key.replace(".resblocks.", ".layers.") - if ".ln_" in key: - key = key.replace(".ln_", ".layer_norm") - elif ".mlp." in key: - key = key.replace(".c_fc.", ".fc1.") - key = key.replace(".c_proj.", ".fc2.") - elif '.attn.out_proj' in key: - key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") - elif '.attn.in_proj' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in SD: {key}") - elif '.positional_embedding' in key: - key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") - elif '.text_projection' in key: - key = None # 使われない??? - elif '.logit_scale' in key: - key = None # 使われない??? - elif '.token_embedding' in key: - key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") - elif '.ln_final' in key: - key = key.replace(".ln_final", ".final_layer_norm") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - # remove resblocks 23 - if '.resblocks.23.' in key: - continue - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if '.resblocks.23.' in key: - continue - if '.resblocks' in key and '.attn.in_proj_' in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) - - key_suffix = ".weight" if "weight" in key else ".bias" - key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") - key_pfx = key_pfx.replace("_weight", "") - key_pfx = key_pfx.replace("_bias", "") - key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") - new_sd[key_pfx + "q_proj" + key_suffix] = values[0] - new_sd[key_pfx + "k_proj" + key_suffix] = values[1] - new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - - # position_idsの追加 - new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) - return new_sd + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith('cond_stage_model'): + return None + + # common conversion + key = key.replace( + 'cond_stage_model.model.transformer.', 'text_model.encoder.' + ) + key = key.replace('cond_stage_model.model.', 'text_model.') + + if 'resblocks' in key: + # resblocks conversion + key = key.replace('.resblocks.', '.layers.') + if '.ln_' in key: + key = key.replace('.ln_', '.layer_norm') + elif '.mlp.' in key: + key = key.replace('.c_fc.', '.fc1.') + key = key.replace('.c_proj.', '.fc2.') + elif '.attn.out_proj' in key: + key = key.replace('.attn.out_proj.', '.self_attn.out_proj.') + elif '.attn.in_proj' in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f'unexpected key in SD: {key}') + elif '.positional_embedding' in key: + key = key.replace( + '.positional_embedding', + '.embeddings.position_embedding.weight', + ) + elif '.text_projection' in key: + key = None # 使われない??? + elif '.logit_scale' in key: + key = None # 使われない??? + elif '.token_embedding' in key: + key = key.replace( + '.token_embedding.weight', '.embeddings.token_embedding.weight' + ) + elif '.ln_final' in key: + key = key.replace('.ln_final', '.final_layer_norm') + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if '.resblocks.23.' in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if '.resblocks.23.' in key: + continue + if '.resblocks' in key and '.attn.in_proj_' in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = '.weight' if 'weight' in key else '.bias' + key_pfx = key.replace( + 'cond_stage_model.model.transformer.resblocks.', + 'text_model.encoder.layers.', + ) + key_pfx = key_pfx.replace('_weight', '') + key_pfx = key_pfx.replace('_bias', '') + key_pfx = key_pfx.replace('.attn.in_proj', '.self_attn.') + new_sd[key_pfx + 'q_proj' + key_suffix] = values[0] + new_sd[key_pfx + 'k_proj' + key_suffix] = values[1] + new_sd[key_pfx + 'v_proj' + key_suffix] = values[2] + + # position_idsの追加 + new_sd['text_model.embeddings.position_ids'] = torch.Tensor( + [list(range(max_length))] + ).to(torch.int64) + return new_sd + # endregion @@ -634,549 +896,649 @@ def convert_key(key): # region Diffusers->StableDiffusion の変換コード # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) + def conv_transformer_to_linear(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] + keys = list(checkpoint.keys()) + tf_keys = ['proj_in.weight', 'proj_out.weight'] + for key in keys: + if '.'.join(key.split('.')[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] def convert_unet_state_dict_to_sd(v2, unet_state_dict): - unet_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("time_embed.0.weight", "time_embedding.linear_1.weight"), - ("time_embed.0.bias", "time_embedding.linear_1.bias"), - ("time_embed.2.weight", "time_embedding.linear_2.weight"), - ("time_embed.2.bias", "time_embedding.linear_2.bias"), - ("input_blocks.0.0.weight", "conv_in.weight"), - ("input_blocks.0.0.bias", "conv_in.bias"), - ("out.0.weight", "conv_norm_out.weight"), - ("out.0.bias", "conv_norm_out.bias"), - ("out.2.weight", "conv_out.weight"), - ("out.2.bias", "conv_out.bias"), - ] - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0", "norm1"), - ("in_layers.2", "conv1"), - ("out_layers.0", "norm2"), - ("out_layers.3", "conv2"), - ("emb_layers.1", "time_emb_proj"), - ("skip_connection", "conv_shortcut"), - ] - - unet_conversion_map_layer = [] - for i in range(4): - # loop over downblocks/upblocks + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ('time_embed.0.weight', 'time_embedding.linear_1.weight'), + ('time_embed.0.bias', 'time_embedding.linear_1.bias'), + ('time_embed.2.weight', 'time_embedding.linear_2.weight'), + ('time_embed.2.bias', 'time_embedding.linear_2.bias'), + ('input_blocks.0.0.weight', 'conv_in.weight'), + ('input_blocks.0.0.bias', 'conv_in.bias'), + ('out.0.weight', 'conv_norm_out.weight'), + ('out.0.bias', 'conv_norm_out.bias'), + ('out.2.weight', 'conv_out.weight'), + ('out.2.bias', 'conv_out.bias'), + ] + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ('in_layers.0', 'norm1'), + ('in_layers.2', 'conv1'), + ('out_layers.0', 'norm2'), + ('out_layers.3', 'conv2'), + ('emb_layers.1', 'time_emb_proj'), + ('skip_connection', 'conv_shortcut'), + ] + + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f'down_blocks.{i}.resnets.{j}.' + sd_down_res_prefix = f'input_blocks.{3*i + j + 1}.0.' + unet_conversion_map_layer.append( + (sd_down_res_prefix, hf_down_res_prefix) + ) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f'down_blocks.{i}.attentions.{j}.' + sd_down_atn_prefix = f'input_blocks.{3*i + j + 1}.1.' + unet_conversion_map_layer.append( + (sd_down_atn_prefix, hf_down_atn_prefix) + ) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f'up_blocks.{i}.resnets.{j}.' + sd_up_res_prefix = f'output_blocks.{3*i + j}.0.' + unet_conversion_map_layer.append( + (sd_up_res_prefix, hf_up_res_prefix) + ) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f'up_blocks.{i}.attentions.{j}.' + sd_up_atn_prefix = f'output_blocks.{3*i + j}.1.' + unet_conversion_map_layer.append( + (sd_up_atn_prefix, hf_up_atn_prefix) + ) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.conv.' + sd_downsample_prefix = f'input_blocks.{3*(i+1)}.0.op.' + unet_conversion_map_layer.append( + (sd_downsample_prefix, hf_downsample_prefix) + ) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' + sd_upsample_prefix = ( + f'output_blocks.{3*i + 2}.{1 if i == 0 else 2}.' + ) + unet_conversion_map_layer.append( + (sd_upsample_prefix, hf_upsample_prefix) + ) + + hf_mid_atn_prefix = 'mid_block.attentions.0.' + sd_mid_atn_prefix = 'middle_block.1.' + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." - unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - if i > 0: - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - # buyer beware: this is a *brittle* function, - # and correct output requires that all of these pieces interact in - # the exact order in which I have arranged them. - mapping = {k: k for k in unet_state_dict.keys()} - for sd_name, hf_name in unet_conversion_map: - mapping[hf_name] = sd_name - for k, v in mapping.items(): - if "resnets" in k: - for sd_part, hf_part in unet_conversion_map_resnet: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, hf_part in unet_conversion_map_layer: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} - - if v2: - conv_transformer_to_linear(new_state_dict) - - return new_state_dict + hf_mid_res_prefix = f'mid_block.resnets.{j}.' + sd_mid_res_prefix = f'middle_block.{2*j}.' + unet_conversion_map_layer.append( + (sd_mid_res_prefix, hf_mid_res_prefix) + ) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if 'resnets' in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict # ================# # VAE Conversion # # ================# + def reshape_weight_for_sd(w): # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) + return w.reshape(*w.shape, 1, 1) def convert_vae_state_dict(vae_state_dict): - vae_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("nin_shortcut", "conv_shortcut"), - ("norm_out", "conv_norm_out"), - ("mid.attn_1.", "mid_block.attentions.0."), - ] - - for i in range(4): - # down_blocks have two resnets - for j in range(2): - hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." - sd_down_prefix = f"encoder.down.{i}.block.{j}." - vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) - - if i < 3: - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." - sd_downsample_prefix = f"down.{i}.downsample." - vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." - vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) - - # up_blocks have three resnets - # also, up blocks in hf are numbered in reverse from sd - for j in range(3): - hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." - vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) - - # this part accounts for mid blocks in both the encoder and the decoder - for i in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." - vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), - ] - - mapping = {k: k for k in vae_state_dict.keys()} - for k, v in mapping.items(): - for sd_part, hf_part in vae_conversion_map: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - if "attentions" in k: - for sd_part, hf_part in vae_conversion_map_attn: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ["q", "k", "v", "proj_out"] - for k, v in new_state_dict.items(): - for weight_name in weights_to_convert: - if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format") - new_state_dict[k] = reshape_weight_for_sd(v) - - return new_state_dict + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ('nin_shortcut', 'conv_shortcut'), + ('norm_out', 'conv_norm_out'), + ('mid.attn_1.', 'mid_block.attentions.0.'), + ] + + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f'encoder.down_blocks.{i}.resnets.{j}.' + sd_down_prefix = f'encoder.down.{i}.block.{j}.' + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.' + sd_downsample_prefix = f'down.{i}.downsample.' + vae_conversion_map.append( + (sd_downsample_prefix, hf_downsample_prefix) + ) + + hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' + sd_upsample_prefix = f'up.{3-i}.upsample.' + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f'decoder.up_blocks.{i}.resnets.{j}.' + sd_up_prefix = f'decoder.up.{3-i}.block.{j}.' + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f'mid_block.resnets.{i}.' + sd_mid_res_prefix = f'mid.block_{i+1}.' + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ('norm.', 'group_norm.'), + ('q.', 'query.'), + ('k.', 'key.'), + ('v.', 'value.'), + ('proj_out.', 'proj_attn.'), + ] + + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if 'attentions' in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ['q', 'k', 'v', 'proj_out'] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f'mid.attn_1.{weight_name}.weight' in k: + # print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + + return new_state_dict # endregion # region 自作のモデル読み書きなど + def is_safetensors(path): - return os.path.splitext(path)[1].lower() == '.safetensors' + return os.path.splitext(path)[1].lower() == '.safetensors' def load_checkpoint_with_text_encoder_conversion(ckpt_path): - # text encoderの格納形式が違うモデルに対応する ('text_model'がない) - TEXT_ENCODER_KEY_REPLACEMENTS = [ - ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), - ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), - ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') - ] - - if is_safetensors(ckpt_path): - checkpoint = None - state_dict = load_file(ckpt_path, "cpu") - else: - checkpoint = torch.load(ckpt_path, map_location="cpu") - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ( + 'cond_stage_model.transformer.embeddings.', + 'cond_stage_model.transformer.text_model.embeddings.', + ), + ( + 'cond_stage_model.transformer.encoder.', + 'cond_stage_model.transformer.text_model.encoder.', + ), + ( + 'cond_stage_model.transformer.final_layer_norm.', + 'cond_stage_model.transformer.text_model.final_layer_norm.', + ), + ] + + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path, 'cpu') else: - state_dict = checkpoint - checkpoint = None + checkpoint = torch.load(ckpt_path, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + checkpoint = None - key_reps = [] - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - for key in state_dict.keys(): - if key.startswith(rep_from): - new_key = rep_to + key[len(rep_from):] - key_reps.append((key, new_key)) + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from) :] + key_reps.append((key, new_key)) - for key, new_key in key_reps: - state_dict[new_key] = state_dict[key] - del state_dict[key] + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] - return checkpoint, state_dict + return checkpoint, state_dict # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if dtype is not None: - for k, v in state_dict.items(): - if type(v) is torch.Tensor: - state_dict[k] = v.to(dtype) - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(v2) - converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) - - unet = UNet2DConditionModel(**unet_config) - info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) - - # Convert the VAE model. - vae_config = create_vae_diffusers_config() - converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) - - vae = AutoencoderKL(**vae_config) - info = vae.load_state_dict(converted_vae_checkpoint) - print("loadint vae:", info) - - # convert text_model - if v2: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=23, - num_attention_heads=16, - max_position_embeddings=77, - hidden_act="gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=512, - torch_dtype="float32", - transformers_version="4.25.0.dev0", + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if dtype is not None: + for k, v in state_dict.items(): + if type(v) is torch.Tensor: + state_dict[k] = v.to(dtype) + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(v2) + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + v2, state_dict, unet_config + ) + + unet = UNet2DConditionModel(**unet_config) + info = unet.load_state_dict(converted_unet_checkpoint) + print('loading u-net:', info) + + # Convert the VAE model. + vae_config = create_vae_diffusers_config() + converted_vae_checkpoint = convert_ldm_vae_checkpoint( + state_dict, vae_config ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - else: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) - - return text_model, vae, unet - - -def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): - def convert_key(key): - # position_idsの除去 - if ".position_ids" in key: - return None - - # common - key = key.replace("text_model.encoder.", "transformer.") - key = key.replace("text_model.", "") - if "layers" in key: - # resblocks conversion - key = key.replace(".layers.", ".resblocks.") - if ".layer_norm" in key: - key = key.replace(".layer_norm", ".ln_") - elif ".mlp." in key: - key = key.replace(".fc1.", ".c_fc.") - key = key.replace(".fc2.", ".c_proj.") - elif '.self_attn.out_proj' in key: - key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") - elif '.self_attn.' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in DiffUsers model: {key}") - elif '.position_embedding' in key: - key = key.replace("embeddings.position_embedding.weight", "positional_embedding") - elif '.token_embedding' in key: - key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") - elif 'final_layer_norm' in key: - key = key.replace("final_layer_norm", "ln_final") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if 'layers' in key and 'q_proj' in key: - # 三つを結合 - key_q = key - key_k = key.replace("q_proj", "k_proj") - key_v = key.replace("q_proj", "v_proj") - - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") - new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") - new_sd[new_key] = value - - # 最後の層などを捏造するか - if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") - keys = list(new_sd.keys()) + + vae = AutoencoderKL(**vae_config) + info = vae.load_state_dict(converted_vae_checkpoint) + print('loadint vae:', info) + + # convert text_model + if v2: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2( + state_dict, 77 + ) + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=23, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act='gelu', + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type='clip_text_model', + projection_dim=512, + torch_dtype='float32', + transformers_version='4.25.0.dev0', + ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + else: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1( + state_dict + ) + text_model = CLIPTextModel.from_pretrained( + 'openai/clip-vit-large-patch14' + ) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + print('loading text encoder:', info) + + return text_model, vae, unet + + +def convert_text_encoder_state_dict_to_sd_v2( + checkpoint, make_dummy_weights=False +): + def convert_key(key): + # position_idsの除去 + if '.position_ids' in key: + return None + + # common + key = key.replace('text_model.encoder.', 'transformer.') + key = key.replace('text_model.', '') + if 'layers' in key: + # resblocks conversion + key = key.replace('.layers.', '.resblocks.') + if '.layer_norm' in key: + key = key.replace('.layer_norm', '.ln_') + elif '.mlp.' in key: + key = key.replace('.fc1.', '.c_fc.') + key = key.replace('.fc2.', '.c_proj.') + elif '.self_attn.out_proj' in key: + key = key.replace('.self_attn.out_proj.', '.attn.out_proj.') + elif '.self_attn.' in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f'unexpected key in DiffUsers model: {key}') + elif '.position_embedding' in key: + key = key.replace( + 'embeddings.position_embedding.weight', 'positional_embedding' + ) + elif '.token_embedding' in key: + key = key.replace( + 'embeddings.token_embedding.weight', 'token_embedding.weight' + ) + elif 'final_layer_norm' in key: + key = key.replace('final_layer_norm', 'ln_final') + return key + + keys = list(checkpoint.keys()) + new_sd = {} for key in keys: - if key.startswith("transformer.resblocks.22."): - new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] - # Diffusersに含まれない重みを作っておく - new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) - new_sd['logit_scale'] = torch.tensor(1) + # attnの変換 + for key in keys: + if 'layers' in key and 'q_proj' in key: + # 三つを結合 + key_q = key + key_k = key.replace('q_proj', 'k_proj') + key_v = key.replace('q_proj', 'v_proj') + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace( + 'text_model.encoder.layers.', 'transformer.resblocks.' + ) + new_key = new_key.replace('.self_attn.q_proj.', '.attn.in_proj_') + new_sd[new_key] = value + + # 最後の層などを捏造するか + if make_dummy_weights: + print( + 'make dummy weights for resblock.23, text_projection and logit scale.' + ) + keys = list(new_sd.keys()) + for key in keys: + if key.startswith('transformer.resblocks.22.'): + new_sd[key.replace('.22.', '.23.')] = new_sd[ + key + ].clone() # copyしないとsafetensorsの保存で落ちる + + # Diffusersに含まれない重みを作っておく + new_sd['text_projection'] = torch.ones( + (1024, 1024), + dtype=new_sd[keys[0]].dtype, + device=new_sd[keys[0]].device, + ) + new_sd['logit_scale'] = torch.tensor(1) - return new_sd + return new_sd -def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): - if ckpt_path is not None: - # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む - checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if checkpoint is None: # safetensors または state_dictのckpt - checkpoint = {} - strict = False +def save_stable_diffusion_checkpoint( + v2, + output_file, + text_encoder, + unet, + ckpt_path, + epochs, + steps, + save_dtype=None, + vae=None, +): + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion( + ckpt_path + ) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False + else: + strict = True + if 'state_dict' in state_dict: + del state_dict['state_dict'] else: - strict = True - if "state_dict" in state_dict: - del state_dict["state_dict"] - else: - # 新しく作る - assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" - checkpoint = {} - state_dict = {} - strict = False - - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - assert not strict or key in state_dict, f"Illegal key in save SD: {key}" - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) - update_sd("model.diffusion_model.", unet_state_dict) - - # Convert the text encoder model - if v2: - make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる - text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) - update_sd("cond_stage_model.model.", text_enc_dict) - else: - text_enc_dict = text_encoder.state_dict() - update_sd("cond_stage_model.transformer.", text_enc_dict) - - # Convert the VAE - if vae is not None: - vae_dict = convert_vae_state_dict(vae.state_dict()) - update_sd("first_stage_model.", vae_dict) - - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {'state_dict': state_dict} - - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] - - new_ckpt['epoch'] = epochs - new_ckpt['global_step'] = steps - - if is_safetensors(output_file): - # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file) - else: - torch.save(new_ckpt, output_file) - - return key_count - - -def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): - if pretrained_model_name_or_path is None: - # load default settings for v1/v2 + # 新しく作る + assert ( + vae is not None + ), 'VAE is required to save a checkpoint without a given checkpoint' + checkpoint = {} + state_dict = {} + strict = False + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert ( + not strict or key in state_dict + ), f'Illegal key in save SD: {key}' + if save_dtype is not None: + v = v.detach().clone().to('cpu').to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd('model.diffusion_model.', unet_state_dict) + + # Convert the text encoder model if v2: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 + make_dummy = ( + ckpt_path is None + ) # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2( + text_encoder.state_dict(), make_dummy + ) + update_sd('cond_stage_model.model.', text_enc_dict) else: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 + text_enc_dict = text_encoder.state_dict() + update_sd('cond_stage_model.transformer.', text_enc_dict) - scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") - if vae is None: - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + # Convert the VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd('first_stage_model.', vae_dict) - pipeline = StableDiffusionPipeline( - unet=unet, - text_encoder=text_encoder, - vae=vae, - scheduler=scheduler, - tokenizer=tokenizer, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=None, - ) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {'state_dict': state_dict} + if 'epoch' in checkpoint: + epochs += checkpoint['epoch'] + if 'global_step' in checkpoint: + steps += checkpoint['global_step'] -VAE_PREFIX = "first_stage_model." + new_ckpt['epoch'] = epochs + new_ckpt['global_step'] = steps + if is_safetensors(output_file): + # TODO Tensor以外のdictの値を削除したほうがいいか + save_file(state_dict, output_file) + else: + torch.save(new_ckpt, output_file) -def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") - if os.path.isdir(vae_id) or not os.path.isfile(vae_id): - # Diffusers local/remote - try: - vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) - except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") - vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) - return vae + return key_count - # local - vae_config = create_vae_diffusers_config() - - if vae_id.endswith(".bin"): - # SD 1.5 VAE on Huggingface - vae_sd = torch.load(vae_id, map_location="cpu") - converted_vae_checkpoint = vae_sd - else: - # StableDiffusion - vae_model = torch.load(vae_id, map_location="cpu") - vae_sd = vae_model['state_dict'] - - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd - # Convert the VAE model. - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) +def save_diffusers_checkpoint( + v2, + output_dir, + text_encoder, + unet, + pretrained_model_name_or_path, + vae=None, + use_safetensors=False, +): + if pretrained_model_name_or_path is None: + # load default settings for v1/v2 + if v2: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 + else: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 + + scheduler = DDIMScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder='scheduler' + ) + tokenizer = CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, subfolder='tokenizer' + ) + if vae is None: + vae = AutoencoderKL.from_pretrained( + pretrained_model_name_or_path, subfolder='vae' + ) + + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae + +VAE_PREFIX = 'first_stage_model.' + + +def load_vae(vae_id, dtype): + print(f'load VAE: {vae_id}') + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained( + vae_id, subfolder=None, torch_dtype=dtype + ) + except EnvironmentError as e: + print(f'exception occurs in loading vae: {e}') + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained( + vae_id, subfolder='vae', torch_dtype=dtype + ) + return vae + + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith('.bin'): + # SD 1.5 VAE on Huggingface + vae_sd = torch.load(vae_id, map_location='cpu') + converted_vae_checkpoint = vae_sd + else: + # StableDiffusion + vae_model = torch.load(vae_id, map_location='cpu') + vae_sd = vae_model['state_dict'] + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. + converted_vae_checkpoint = convert_ldm_vae_checkpoint( + vae_sd, vae_config + ) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae def get_epoch_ckpt_name(use_safetensors, epoch): - return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt") + return f'epoch-{epoch:06d}' + ( + '.safetensors' if use_safetensors else '.ckpt' + ) def get_last_ckpt_name(use_safetensors): - return f"last" + (".safetensors" if use_safetensors else ".ckpt") + return f'last' + ('.safetensors' if use_safetensors else '.ckpt') # endregion -def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): - max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) +def make_bucket_resolutions( + max_reso, min_size=256, max_size=1024, divisible=64 +): + max_width, max_height = max_reso + max_area = (max_width // divisible) * (max_height // divisible) - resos = set() + resos = set() - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) + size = int(math.sqrt(max_area)) * divisible + resos.add((size, size)) - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) + size = min_size + while size <= max_size: + width = size + height = min(max_size, (max_area // (width // divisible)) * divisible) + resos.add((width, height)) + resos.add((height, width)) - # # make additional resos - # if width >= height and width - divisible >= min_size: - # resos.add((width - divisible, height)) - # resos.add((height, width - divisible)) - # if height >= width and height - divisible >= min_size: - # resos.add((width, height - divisible)) - # resos.add((height - divisible, width)) + # # make additional resos + # if width >= height and width - divisible >= min_size: + # resos.add((width - divisible, height)) + # resos.add((height, width - divisible)) + # if height >= width and height - divisible >= min_size: + # resos.add((width, height - divisible)) + # resos.add((height - divisible, width)) - size += divisible + size += divisible - resos = list(resos) - resos.sort() + resos = list(resos) + resos.sort() - aspect_ratios = [w / h for w, h in resos] - return resos, aspect_ratios + aspect_ratios = [w / h for w, h in resos] + return resos, aspect_ratios if __name__ == '__main__': - resos, aspect_ratios = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) - print(aspect_ratios) - - ars = set() - for ar in aspect_ratios: - if ar in ars: - print("error! duplicate ar:", ar) - ars.add(ar) + resos, aspect_ratios = make_bucket_resolutions((512, 768)) + print(len(resos)) + print(resos) + print(aspect_ratios) + + ars = set() + for ar in aspect_ratios: + if ar in ars: + print('error! duplicate ar:', ar) + ars.add(ar) From 69558b59513ab6cd46cd59edc79a76958bed7c63 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 19 Dec 2022 21:51:52 -0500 Subject: [PATCH 3/3] Update readme --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index a5d7c7718..5b946fca3 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,9 @@ Drop by the discord server for support: https://discord.com/channels/10415185624 ## Change history +* 12/19 (v18.4) update: + - Add support for shuffle_caption, save_state, resume, prior_loss_weight under "Advanced Configuration" section + - Fix issue with open/save config not working properly * 12/19 (v18.3) update: - fix stop encoder training issue * 12/19 (v18.2) update: