Supporting bfloat16 for tensorflow + jax (was failing because of intermediary numpy). (#382)
Narsil authored Nov 17, 2023
1 parent 96061e9 commit 9e0bc08
Showing 5 changed files with 39 additions and 11 deletions.
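The fix routes the flax and tensorflow `load_file` helpers through `safe_open`, so tensors are materialized directly in the target framework instead of passing through a plain numpy dict, which cannot represent bfloat16. A minimal sketch of the round-trip this enables, mirroring the updated flax test (file path and tensor name are illustrative; assumes jax and a safetensors build containing this change):

```python
# Round-trip a bfloat16 tensor through the flax helpers (illustrative name and path).
import jax.numpy as jnp
from safetensors.flax import save_file, load_file

tensors = {"emb": jnp.zeros((8, 8), dtype=jnp.bfloat16)}
save_file(tensors, "model.safetensors")

loaded = load_file("model.safetensors")   # load path now goes through safe_open
assert loaded["emb"].dtype == jnp.bfloat16
```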
9 changes: 6 additions & 3 deletions bindings/python/py_src/safetensors/flax.py
@@ -5,7 +5,7 @@

 import jax.numpy as jnp
 from jax import Array
-from safetensors import numpy
+from safetensors import numpy, safe_open


 def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes:
@@ -122,8 +122,11 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]:
     loaded = load_file(file_path)
     ```
     """
-    flat = numpy.load_file(filename)
-    return _np2jnp(flat)
+    result = {}
+    with safe_open(filename, framework="flax") as f:
+        for k in f.keys():
+            result[k] = f.get_tensor(k)
+    return result


 def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]:
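For context, the replaced implementation built a dict of plain numpy arrays first and only then converted to jax, and stock numpy cannot represent bfloat16, which is the "intermediary numpy" failure mentioned in the commit message. A small illustration (only numpy is assumed):

```python
# Stock numpy has no registered "bfloat16" dtype name, so any load path that
# materializes a plain numpy array drops this dtype before conversion.
import numpy as np

try:
    np.dtype("bfloat16")
except TypeError as exc:
    print(exc)   # "data type 'bfloat16' not understood" (exact message may vary)
```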
9 changes: 6 additions & 3 deletions bindings/python/py_src/safetensors/tensorflow.py
@@ -4,7 +4,7 @@
 import numpy as np
 import tensorflow as tf

-from safetensors import numpy
+from safetensors import numpy, safe_open


 def save(tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
@@ -121,8 +121,11 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]:
     loaded = load_file(file_path)
     ```
     """
-    flat = numpy.load_file(filename)
-    return _np2tf(flat)
+    result = {}
+    with safe_open(filename, framework="tf") as f:
+        for k in f.keys():
+            result[k] = f.get_tensor(k)
+    return result


 def _np2tf(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, tf.Tensor]:
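The tensorflow helper gets the same treatment: `load_file` now materializes each tensor directly as a `tf.Tensor` via `safe_open`, so bfloat16 never passes through numpy. A short usage sketch mirroring the new test (path is illustrative; assumes tensorflow and this safetensors build):

```python
# Save and reload a bfloat16 tensor with the tensorflow helpers (illustrative path).
import tensorflow as tf
from safetensors import safe_open
from safetensors.tensorflow import save_file

save_file({"w": tf.zeros((4, 4), dtype=tf.bfloat16)}, "weights.safetensors")

tensors = {}
with safe_open("weights.safetensors", framework="tf") as f:
    for name in f.keys():
        tensors[name] = f.get_tensor(name)   # tf.Tensor with dtype bfloat16
```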
11 changes: 10 additions & 1 deletion bindings/python/src/lib.rs
@@ -962,7 +962,16 @@ fn get_pydtype(module: &PyModule, dtype: Dtype, is_numpy: bool) -> PyResult<PyObject> {
         let dtype: PyObject = match dtype {
             Dtype::F64 => module.getattr(intern!(py, "float64"))?.into(),
             Dtype::F32 => module.getattr(intern!(py, "float32"))?.into(),
-            Dtype::BF16 => module.getattr(intern!(py, "bfloat16"))?.into(),
+            Dtype::BF16 => {
+                if is_numpy {
+                    module
+                        .getattr(intern!(py, "dtype"))?
+                        .call1(("bfloat16",))?
+                        .into()
+                } else {
+                    module.getattr(intern!(py, "bfloat16"))?.into()
+                }
+            }
             Dtype::F16 => module.getattr(intern!(py, "float16"))?.into(),
             Dtype::U64 => module.getattr(intern!(py, "uint64"))?.into(),
             Dtype::I64 => module.getattr(intern!(py, "int64"))?.into(),
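The Rust change only alters how the Python dtype object is resolved: tensorflow and jax.numpy expose bfloat16 as a module attribute, whereas plain numpy does not, so the numpy branch asks the module to construct the dtype by name instead (which presumably only succeeds when an extension dtype of that name has been registered). A hedged Python-level sketch of the distinction the `is_numpy` branch encodes:

```python
# The attribute lookup used by the non-numpy branch works for tf and jax.numpy
# but not for stock numpy, which is why the numpy branch falls back to building
# the dtype by name via module.dtype("bfloat16").
import numpy as np
import tensorflow as tf
import jax.numpy as jnp

print(hasattr(tf, "bfloat16"))    # True
print(hasattr(jnp, "bfloat16"))   # True
print(hasattr(np, "bfloat16"))    # False on stock numpy
```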
7 changes: 3 additions & 4 deletions bindings/python/tests/test_flax_comparison.py
@@ -1,8 +1,6 @@
 import platform
 import unittest

-import numpy as np
-

 if platform.system() != "Windows":
     # This platform is not supported, we don't want to crash on import
@@ -21,6 +19,7 @@ def setUp(self):
             "test": jnp.zeros((1024, 1024), dtype=jnp.float32),
             "test2": jnp.zeros((1024, 1024), dtype=jnp.float32),
             "test3": jnp.zeros((1024, 1024), dtype=jnp.float32),
+            "test4": jnp.zeros((1024, 1024), dtype=jnp.bfloat16),
         }
         self.flax_filename = "./tests/data/flax_load.msgpack"
         self.sf_filename = "./tests/data/flax_load.safetensors"
@@ -51,7 +50,7 @@ def test_deserialization_safe(self):

         for k, v in weights.items():
             tv = flax_weights[k]
-            self.assertTrue(np.allclose(v, tv))
+            self.assertTrue(jnp.allclose(v, tv))

     def test_deserialization_safe_open(self):
         weights = {}
@@ -65,4 +64,4 @@ def test_deserialization_safe_open(self):

         for k, v in weights.items():
             tv = flax_weights[k]
-            self.assertTrue(np.allclose(v, tv))
+            self.assertTrue(jnp.allclose(v, tv))
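The assertions switch from np.allclose to jnp.allclose, presumably so the new bfloat16 tensor ("test4") is compared inside jax's own dtype handling rather than relying on numpy's handling of the extension dtype. A tiny sketch of that comparison (values are illustrative):

```python
# Compare bfloat16 arrays with jnp.allclose, staying within jax (illustrative values).
import jax.numpy as jnp

a = jnp.full((4,), 0.5, dtype=jnp.bfloat16)
b = jnp.full((4,), 0.5, dtype=jnp.bfloat16)
assert bool(jnp.allclose(a, b))
```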
14 changes: 14 additions & 0 deletions bindings/python/tests/test_tf_comparison.py
@@ -62,6 +62,20 @@ def test_deserialization_safe(self):
             tv = tf_weights[k]
             self.assertTrue(np.allclose(v, tv))

+    def test_bfloat16(self):
+        data = {
+            "test": tf.zeros((1024, 1024), dtype=tf.bfloat16),
+        }
+        save_file(data, self.sf_filename)
+        weights = {}
+        with safe_open(self.sf_filename, framework="tf") as f:
+            for k in f.keys():
+                weights[k] = f.get_tensor(k)
+
+        for k, v in weights.items():
+            tv = data[k]
+            self.assertTrue(tf.experimental.numpy.allclose(v, tv))
+
     def test_deserialization_safe_open(self):
         weights = {}
         with safe_open(self.sf_filename, framework="tf") as f:
