[Tests] [LoRA] clean up the serialization stuff. (#9512)
* clean up the serialization stuff.

* better
sayakpaul authored Sep 27, 2024
1 parent 534848c commit 81cf3b2
Showing 1 changed file with 41 additions and 73 deletions.
tests/lora/utils.py (114 changes: 41 additions & 73 deletions)
@@ -201,6 +201,32 @@ def get_dummy_tokens(self):
         prepared_inputs["input_ids"] = inputs
         return prepared_inputs

+    def _get_lora_state_dicts(self, modules_to_save):
+        state_dicts = {}
+        for module_name, module in modules_to_save.items():
+            if module is not None:
+                state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
+        return state_dicts
+
+    def _get_modules_to_save(self, pipe, has_denoiser=False):
+        modules_to_save = {}
+        lora_loadable_modules = self.pipeline_class._lora_loadable_modules
+
+        if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
+            modules_to_save["text_encoder"] = pipe.text_encoder
+
+        if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
+            modules_to_save["text_encoder_2"] = pipe.text_encoder_2
+
+        if has_denoiser:
+            if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
+                modules_to_save["unet"] = pipe.unet
+
+            if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
+                modules_to_save["transformer"] = pipe.transformer
+
+        return modules_to_save
+
     def test_simple_inference(self):
         """
         Tests a simple inference and makes sure it works as expected
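
Taken together, the two new helpers encode a small contract: _get_modules_to_save collects only the modules that are both listed in the pipeline class's _lora_loadable_modules and actually present on the pipeline, while _get_lora_state_dicts keys each PEFT state dict as f"{module_name}_lora_layers", which matches the keyword arguments save_lora_weights already accepts (text_encoder_lora_layers, unet_lora_layers, transformer_lora_layers, and so on). That naming is what lets the call sites in the hunks below splat the result straight into the call. A minimal sketch of the intended flow, assuming an SDXL-style pipeline with a UNet and two text encoders (pipe and tmpdirname stand in for fixtures the tests build):

    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    # -> {"text_encoder": pipe.text_encoder, "text_encoder_2": pipe.text_encoder_2, "unet": pipe.unet}
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    # -> {"text_encoder_lora_layers": ..., "text_encoder_2_lora_layers": ..., "unet_lora_layers": ...}
    self.pipeline_class.save_lora_weights(
        save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
    )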
@@ -420,45 +446,21 @@ def test_simple_inference_with_text_lora_save_load(self):
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
-            text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
-
-                    self.pipeline_class.save_lora_weights(
-                        save_directory=tmpdirname,
-                        text_encoder_lora_layers=text_encoder_state_dict,
-                        text_encoder_2_lora_layers=text_encoder_2_state_dict,
-                        safe_serialization=False,
-                    )
-            else:
-                self.pipeline_class.save_lora_weights(
-                    save_directory=tmpdirname,
-                    text_encoder_lora_layers=text_encoder_state_dict,
-                    safe_serialization=False,
-                )
-
-            if self.has_two_text_encoders:
-                if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules:
-                    self.pipeline_class.save_lora_weights(
-                        save_directory=tmpdirname,
-                        text_encoder_lora_layers=text_encoder_state_dict,
-                        safe_serialization=False,
-                    )
+            modules_to_save = self._get_modules_to_save(pipe)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )

             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             pipe.unload_lora_weights()

             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

-            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                    )
+            for module_name, module in modules_to_save.items():
+                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
+            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

             self.assertTrue(
                 np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
@@ -614,54 +616,20 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
-            text_encoder_state_dict = (
-                get_peft_model_state_dict(pipe.text_encoder)
-                if "text_encoder" in self.pipeline_class._lora_loadable_modules
-                else None
-            )
-
-            denoiser_state_dict = get_peft_model_state_dict(denoiser)
-
-            saving_kwargs = {
-                "save_directory": tmpdirname,
-                "safe_serialization": False,
-            }
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                saving_kwargs.update({"text_encoder_lora_layers": text_encoder_state_dict})
-
-            if self.unet_kwargs is not None:
-                saving_kwargs.update({"unet_lora_layers": denoiser_state_dict})
-            else:
-                saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict})
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
-                    saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict})
-
-            self.pipeline_class.save_lora_weights(**saving_kwargs)
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )

             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             pipe.unload_lora_weights()

             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

-            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
-
-            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                    )
+            for module_name, module in modules_to_save.items():
+                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
+            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
                 np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                 "Loading from saved checkpoints should give same results.",