Skip to content

Commit

Permalink
Merge pull request #676 from Isotr0py/sdxl
Browse files Browse the repository at this point in the history
Fix RAM leak when loading SDXL model in lowram device
  • Loading branch information
kohya-ss authored Jul 30, 2023
2 parents 4072f72 + d9180c0 commit e20b6ac
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 10 deletions.
43 changes: 38 additions & 5 deletions library/sdxl_model_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
Expand Down Expand Up @@ -133,13 +136,43 @@ def convert_key(key):
return new_sd, logit_scale


def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
def _load_state_dict(model, state_dict, device, dtype=None):
# dtype will use fp32 as default
missing_keys = list(model.state_dict().keys() - state_dict.keys())
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())

# similar to model.load_state_dict()
if not missing_keys and not unexpected_keys:
for k in list(state_dict.keys()):
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
return '<All keys matched successfully>'

# error_msgs
error_msgs: List[str] = []
if missing_keys:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys)))
if unexpected_keys:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))


def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
# model_version is reserved for future use
# dtype is reserved for full_fp16/bf16 integration

# Load the state dict
if model_util.is_safetensors(ckpt_path):
checkpoint = None
state_dict = load_file(ckpt_path, device=map_location)
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
epoch = None
global_step = None
else:
Expand All @@ -156,16 +189,16 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):

# U-Net
print("building U-Net")
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()

print("loading U-Net from checkpoint")
unet_sd = {}
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = unet.load_state_dict(unet_sd)
info = _load_state_dict(unet, unet_sd, device=map_location)
print("U-Net: ", info)
del unet_sd

# Text Encoders
print("building text encoders")
Expand Down
11 changes: 6 additions & 5 deletions library/sdxl_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from typing import Optional
import torch
from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
Expand Down Expand Up @@ -66,7 +67,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
unet,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
Expand All @@ -75,7 +76,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
try:
try:
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=variant, tokenizer=None)
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None)
except EnvironmentError as ex:
if variant is not None:
print("try to load fp32 model")
Expand All @@ -95,10 +96,10 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
del pipe

# Diffusers U-Net to original U-Net
original_unet = sdxl_original_unet.SdxlUNet2DConditionModel()
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
original_unet.load_state_dict(state_dict)
unet = original_unet
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
sdxl_model_util._load_state_dict(unet, state_dict, device=device)
print("U-Net converted to original U-Net")

logit_scale = None
Expand Down

0 comments on commit e20b6ac

Please sign in to comment.