Adding support for safetensors and LoRa. (huggingface#2448)
* Adding support for `safetensors` and LoRa.

* Adding metadata.
Narsil authored and w4ffl35 committed Apr 14, 2023
1 parent f311de3 commit cdf42b8
Showing 2 changed files with 121 additions and 18 deletions.
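For orientation, here is a minimal usage sketch of the API this commit extends; it is not part of the diff and assumes a `UNet2DConditionModel` named `unet` whose attention processors have already been replaced with `LoRACrossAttnProcessor` instances (as done in the tests added below). With `safetensors` installed, `save_attn_procs(..., safe_serialization=True)` writes pytorch_lora_weights.safetensors, and `load_attn_procs` looks for that file first before falling back to pytorch_lora_weights.bin.

import tempfile

# Hedged sketch (assumes `unet` is a UNet2DConditionModel with LoRA attention
# processors already set via unet.set_attn_processor(...)).
with tempfile.TemporaryDirectory() as tmpdir:
    # New safe_serialization flag: writes pytorch_lora_weights.safetensors.
    unet.save_attn_procs(tmpdir, safe_serialization=True)

    # Tries the .safetensors file first when safetensors is available,
    # otherwise falls back to pytorch_lora_weights.bin.
    unet.load_attn_procs(tmpdir)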
src/diffusers/loaders.py (79 changes: 61 additions & 18 deletions)
@@ -19,13 +19,18 @@

from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging


if is_safetensors_available():
import safetensors


logger = logging.get_logger(__name__)


LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"


class AttnProcsLayers(torch.nn.Module):
@@ -136,28 +141,53 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
weight_name = kwargs.pop("weight_name", None)

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}

model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
if is_safetensors_available():
if weight_name is None:
weight_name = LORA_WEIGHT_NAME_SAFE
try:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except EnvironmentError:
if weight_name == LORA_WEIGHT_NAME_SAFE:
weight_name = None
if model_file is None:
if weight_name is None:
weight_name = LORA_WEIGHT_NAME
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict

@@ -195,8 +225,9 @@ def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
weights_name: str = LORA_WEIGHT_NAME,
weights_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
):
r"""
Save an attention processor to a directory, so that it can be re-loaded using the
@@ -219,7 +250,13 @@ def save_attn_procs(
return

if save_function is None:
save_function = torch.save
if safe_serialization:

def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

else:
save_function = torch.save

os.makedirs(save_directory, exist_ok=True)

@@ -237,6 +274,12 @@ def save_attn_procs(
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)

if weights_name is None:
if safe_serialization:
weights_name = LORA_WEIGHT_NAME_SAFE
else:
weights_name = LORA_WEIGHT_NAME

# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))

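A note on the metadata mentioned in the commit message: when safe_serialization is used, the weights are written with safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}), so the format tag lives in the file header. The following is a small illustrative sketch, not part of the diff, using the existing safetensors safe_open API and a hypothetical local file path, of how that header can be inspected:

from safetensors import safe_open

# Read back the header metadata and tensors written by
# save_attn_procs(..., safe_serialization=True).
with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    print(f.metadata())  # expected to contain {"format": "pt"}
    state_dict = {key: f.get_tensor(key) for key in f.keys()}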
tests/models/test_models_unet_2d_condition.py (60 changes: 60 additions & 0 deletions)
@@ -14,6 +14,7 @@
# limitations under the License.

import gc
import os
import tempfile
import unittest

@@ -372,6 +373,65 @@ def test_lora_save_load(self):

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)

with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

assert (sample - new_sample).abs().max() < 1e-4

# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4

def test_lora_save_load_safetensors(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["attention_head_dim"] = (8, 16)

torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)

with torch.no_grad():
old_sample = model(**inputs_dict).sample

lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1

model.set_attn_processor(lora_attn_procs)

with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=True)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
