diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index beb2b9ef4e7e..cfc7c5896f97 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -19,13 +19,18 @@
 
 from .models.cross_attention import LoRACrossAttnProcessor
 from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
+from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
+
+
+if is_safetensors_available():
+    import safetensors
 
 
 logger = logging.get_logger(__name__)
 
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 
 
 class AttnProcsLayers(torch.nn.Module):
@@ -136,28 +141,53 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
+        weight_name = kwargs.pop("weight_name", None)
 
         user_agent = {
             "file_type": "attn_procs_weights",
             "framework": "pytorch",
         }
 
+        model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            model_file = _get_model_file(
-                pretrained_model_name_or_path_or_dict,
-                weights_name=weight_name,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            state_dict = torch.load(model_file, map_location="cpu")
+            if is_safetensors_available():
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME_SAFE
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except EnvironmentError:
+                    if weight_name == LORA_WEIGHT_NAME_SAFE:
+                        weight_name = None
+            if model_file is None:
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
@@ -195,8 +225,9 @@ def save_attn_procs(
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        weights_name: str = LORA_WEIGHT_NAME,
+        weights_name: str = None,
         save_function: Callable = None,
+        safe_serialization: bool = False,
     ):
         r"""
         Save an attention processor to a directory, so that it can be re-loaded using the
@@ -219,7 +250,13 @@
             return
 
         if save_function is None:
-            save_function = torch.save
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
 
         os.makedirs(save_directory, exist_ok=True)
 
@@ -237,6 +274,12 @@
             if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
                 os.remove(full_filename)
 
+        if weights_name is None:
+            if safe_serialization:
+                weights_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weights_name = LORA_WEIGHT_NAME
+
         # Save the model
         save_function(state_dict, os.path.join(save_directory, weights_name))
 
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 9909bf3f29ab..bc025f6eeb56 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import gc
+import os
 import tempfile
 import unittest
 
@@ -372,6 +373,65 @@ def test_lora_save_load(self):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_load_safetensors(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = {}
+        for name in model.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = model.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = model.config.block_out_channels[block_id]
+
+            lora_attn_procs[name] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
+            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1
+
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname, safe_serialization=True)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
             new_model.to(torch_device)
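
Usage sketch (not part of the diff above): one way to exercise the new safetensors round trip on a UNet2DConditionModel, mirroring the test; the checkpoint id and output directory are placeholders, and the LoRA layers here are untrained.

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# Placeholder checkpoint; any UNet2DConditionModel checkpoint works.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Attach (untrained) LoRA attention processors so there is something to serialize.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
unet.set_attn_processor(lora_attn_procs)

# Writes pytorch_lora_weights.safetensors instead of the .bin default.
unet.save_attn_procs("./lora-example", safe_serialization=True)

# load_attn_procs now tries the .safetensors file first (when the safetensors
# package is installed) and falls back to pytorch_lora_weights.bin.
unet.load_attn_procs("./lora-example")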