Skip to content

Commit

Permalink
is_safetensors_compatible refactor (#2499)
Browse files Browse the repository at this point in the history
* is_safetensors_compatible refactor

* files list comma
  • Loading branch information
williamberman authored Mar 1, 2023
1 parent a75ac3f commit 856dad5
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 14 deletions.
56 changes: 42 additions & 14 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,49 @@ class AudioPipelineOutput(BaseOutput):


def is_safetensors_compatible(filenames, variant=None) -> bool:
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)

for pt_filename in pt_filenames:
_variant = f".{variant}" if (variant is not None and variant in pt_filename) else ""
prefix, raw = os.path.split(pt_filename)
if raw == f"pytorch_model{_variant}.bin":
# transformers specific
sf_filename = os.path.join(prefix, f"model{_variant}.safetensors")
"""
Checking for safetensors compatibility:
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
files to know which safetensors files are needed.
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
Converting default pytorch serialized filenames to safetensors serialized filenames:
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
pt_filenames = []

sf_filenames = set()

for filename in filenames:
_, extension = os.path.splitext(filename)

if extension == ".bin":
pt_filenames.append(filename)
elif extension == ".safetensors":
sf_filenames.add(filename)

for filename in pt_filenames:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
path, filename = os.path.split(filename)
filename, extension = os.path.splitext(filename)

if filename == "pytorch_model":
filename = "model"
elif filename == f"pytorch_model.{variant}":
filename = f"model.{variant}"
else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
if is_safetensors_compatible and sf_filename not in filenames:
logger.warning(f"{sf_filename} not found")
is_safetensors_compatible = False
return is_safetensors_compatible
filename = filename

expected_sf_filename = os.path.join(path, filename)
expected_sf_filename = f"{expected_sf_filename}.safetensors"

if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found")
return False

return True


def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
Expand Down
134 changes: 134 additions & 0 deletions tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import unittest

from diffusers.pipelines.pipeline_utils import is_safetensors_compatible


class IsSafetensorsCompatibleTests(unittest.TestCase):
def test_all_is_compatible(self):
filenames = [
"safety_checker/pytorch_model.bin",
"safety_checker/model.safetensors",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))

def test_diffusers_model_is_compatible(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))

def test_diffusers_model_is_not_compatible(self):
filenames = [
"safety_checker/pytorch_model.bin",
"safety_checker/model.safetensors",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
"unet/diffusion_pytorch_model.bin",
# Removed: 'unet/diffusion_pytorch_model.safetensors',
]
self.assertFalse(is_safetensors_compatible(filenames))

def test_transformer_model_is_compatible(self):
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))

def test_transformer_model_is_not_compatible(self):
filenames = [
"safety_checker/pytorch_model.bin",
"safety_checker/model.safetensors",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/pytorch_model.bin",
# Removed: 'text_encoder/model.safetensors',
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))

def test_all_is_compatible_variant(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

def test_diffusers_model_is_compatible_variant(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

def test_diffusers_model_is_compatible_variant_partial(self):
# pass variant but use the non-variant filenames
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

def test_diffusers_model_is_not_compatible_variant(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
"unet/diffusion_pytorch_model.fp16.bin",
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
]
variant = "fp16"
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))

def test_transformer_model_is_compatible_variant(self):
filenames = [
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

def test_transformer_model_is_compatible_variant_partial(self):
# pass variant but use the non-variant filenames
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

def test_transformer_model_is_not_compatible_variant(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
# 'text_encoder/model.fp16.safetensors',
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))

0 comments on commit 856dad5

Please sign in to comment.