diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index a7f03afe1..f490dfac8 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -1,6 +1,9 @@ import torch +from accelerate import init_empty_weights +from accelerate.utils.modeling import set_module_tensor_to_device from safetensors.torch import load_file, save_file from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer +from typing import List from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet @@ -133,13 +136,43 @@ def convert_key(key): return new_sd, logit_scale -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): +def _load_state_dict(model, state_dict, device, dtype=None): + # dtype will use fp32 as default + missing_keys = list(model.state_dict().keys() - state_dict.keys()) + unexpected_keys = list(state_dict.keys() - model.state_dict().keys()) + + # similar to model.load_state_dict() + if not missing_keys and not unexpected_keys: + for k in list(state_dict.keys()): + set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) + return '' + + # error_msgs + error_msgs: List[str] = [] + if missing_keys: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys))) + if unexpected_keys: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys))) + + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + + +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): # model_version is reserved for future use + # dtype is reserved for full_fp16/bf16 integration # Load the state dict if model_util.is_safetensors(ckpt_path): checkpoint = None - state_dict = load_file(ckpt_path, device=map_location) + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error epoch = None global_step = None else: @@ -156,16 +189,16 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): # U-Net print("building U-Net") - unet = sdxl_original_unet.SdxlUNet2DConditionModel() + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() print("loading U-Net from checkpoint") unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) - info = unet.load_state_dict(unet_sd) + info = _load_state_dict(unet, unet_sd, device=map_location) print("U-Net: ", info) - del unet_sd # Text Encoders print("building text encoders") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 6ff0d48f5..ebcc3d399 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -4,6 +4,7 @@ import os from typing import Optional import torch +from accelerate import init_empty_weights from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet @@ -66,7 +67,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version unet, logit_scale, ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device) + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype) else: # Diffusers model is loaded to CPU from diffusers import StableDiffusionXLPipeline @@ -75,7 +76,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: - pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=variant, tokenizer=None) + pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None) except EnvironmentError as ex: if variant is not None: print("try to load fp32 model") @@ -95,10 +96,10 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version del pipe # Diffusers U-Net to original U-Net - original_unet = sdxl_original_unet.SdxlUNet2DConditionModel() state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) - original_unet.load_state_dict(state_dict) - unet = original_unet + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() + sdxl_model_util._load_state_dict(unet, state_dict, device=device) print("U-Net converted to original U-Net") logit_scale = None