Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for safetensors and LoRa. #2448

Merged
merged 2 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 61 additions & 18 deletions src/diffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@

from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging


if is_safetensors_available():
import safetensors


logger = logging.get_logger(__name__)


LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"


class AttnProcsLayers(torch.nn.Module):
Expand Down Expand Up @@ -136,28 +141,53 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
weight_name = kwargs.pop("weight_name", None)

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}

model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
if is_safetensors_available():
if weight_name is None:
weight_name = LORA_WEIGHT_NAME_SAFE
try:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except EnvironmentError:
if weight_name == LORA_WEIGHT_NAME_SAFE:
weight_name = None
if model_file is None:
if weight_name is None:
weight_name = LORA_WEIGHT_NAME
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict

Expand Down Expand Up @@ -195,8 +225,9 @@ def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
weights_name: str = LORA_WEIGHT_NAME,
weights_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
):
r"""
Save an attention processor to a directory, so that it can be re-loaded using the
Expand All @@ -219,7 +250,13 @@ def save_attn_procs(
return

if save_function is None:
save_function = torch.save
if safe_serialization:

def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

else:
save_function = torch.save

os.makedirs(save_directory, exist_ok=True)

Expand All @@ -237,6 +274,12 @@ def save_attn_procs(
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)

if weights_name is None:
if safe_serialization:
weights_name = LORA_WEIGHT_NAME_SAFE
else:
weights_name = LORA_WEIGHT_NAME

# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))

Expand Down
60 changes: 60 additions & 0 deletions tests/models/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import gc
import os
import tempfile
import unittest

Expand Down Expand Up @@ -372,6 +373,65 @@ def test_lora_save_load(self):

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)

with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

assert (sample - new_sample).abs().max() < 1e-4

# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4

def test_lora_save_load_safetensors(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["attention_head_dim"] = (8, 16)

torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)

with torch.no_grad():
old_sample = model(**inputs_dict).sample

lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1

model.set_attn_processor(lora_attn_procs)

with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=True)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
Expand Down