diff --git a/README.md b/README.md
index 5bc74b23..4fd32c4f 100644
--- a/README.md
+++ b/README.md
@@ -88,20 +88,20 @@ with safe_open("model.safetensors", framework="pt", device="cpu") as f:
 Notes:
  - Duplicate keys are disallowed. Not all parsers may respect this.
  - In general the subset of JSON is implicitly decided by `serde_json` for
- this library. Anything obscure might be modified at a later time, that odd ways
- to represent integer, newlines and escapes in utf-8 strings. This would only
- be done for safety concerns
+   this library. Anything obscure might be modified at a later time, that odd ways
+   to represent integer, newlines and escapes in utf-8 strings. This would only
+   be done for safety concerns
  - Tensor values are not checked against, in particular NaN and +/-Inf could
- be in the file
+   be in the file
  - Empty tensors (tensors with 1 dimension being 0) are allowed.
- They are not storing any data in the databuffer, yet retaining size in the header.
- They don't really bring a lot of values but are accepted since they are valid tensors
- from traditional tensor libraries perspective (torch, tensorflow, numpy, ..).
+   They are not storing any data in the databuffer, yet retaining size in the header.
+   They don't really bring a lot of values but are accepted since they are valid tensors
+   from traditional tensor libraries perspective (torch, tensorflow, numpy, ..).
  - 0-rank Tensors (tensors with shape `[]`) are allowed, they are merely a scalar.
  - The byte buffer needs to be entirely indexed, and cannot contain holes. This prevents
-the creation of polyglot files.
+   the creation of polyglot files.
  - Endianness: Little-endian.
-moment.
+   moment.
  - Order: 'C' or row-major.
@@ -132,12 +132,12 @@ This is my very personal and probably biased view:
 - Safe: Can I use a file randomly downloaded and expect not to run arbitrary code ?
 - Zero-copy: Does reading the file require more memory than the original file ?
 - Lazy loading: Can I inspect the file without loading everything ? And loading only
-some tensors in it without scanning the whole file (distributed setting) ?
+  some tensors in it without scanning the whole file (distributed setting) ?
 - Layout control: Lazy loading, is not necessarily enough since if the information about tensors is spread out in your file, then even if the information is lazily accessible you might have to access most of your file to read the available tensors (incurring many DISK -> RAM copies). Controlling the layout to keep fast access to single tensors is important.
 - No file size limit: Is there a limit to the file size ?
 - Flexibility: Can I save custom code in the format and be able to use it later with zero extra code ? (~ means we can store more than pure tensors, but no custom code)
 - Bfloat16/Fp8: Does the format support native bfloat16/fp8 (meaning no weird workarounds are
-necessary)? This is becoming increasingly important in the ML world.
+  necessary)? This is becoming increasingly important in the ML world.
 
 ### Main oppositions
@@ -154,12 +154,12 @@ necessary)? This is becoming increasingly important in the ML world.
 ### Notes
 - Zero-copy: No format is really zero-copy in ML, it needs to go from disk to RAM/GPU RAM (that takes time). On CPU, if the file is already in cache, then it can
-truly be zero-copy, whereas on GPU there is not such disk cache, so a copy is always required
-but you can bypass allocating all the tensors on CPU at any given point.
- SafeTensors is not zero-copy for the header. The choice of JSON is pretty arbitrary, but since deserialization is <<< of the time required to load the actual tensor data and is readable I went that way, (also space is <<< to the tensor data).
+  truly be zero-copy, whereas on GPU there is not such disk cache, so a copy is always required
+  but you can bypass allocating all the tensors on CPU at any given point.
+  SafeTensors is not zero-copy for the header. The choice of JSON is pretty arbitrary, but since deserialization is <<< of the time required to load the actual tensor data and is readable I went that way, (also space is <<< to the tensor data).
 - Endianness: Little-endian. This can be modified later, but it feels really unnecessary at the
-moment.
+  moment.
 - Order: 'C' or row-major. This seems to have won. We can add that information later if needed.
 - Stride: No striding, all tensors need to be packed before being serialized. I have yet to see a case where it seems useful to have a strided tensor stored in serialized format.
@@ -168,26 +168,26 @@ moment.
 Since we can invent a new format we can propose additional benefits:
 - Prevent DOS attacks: We can craft the format in such a way that it's almost
-impossible to use malicious files to DOS attack a user. Currently, there's a limit
-on the size of the header of 100MB to prevent parsing extremely large JSON.
- Also when reading the file, there's a guarantee that addresses in the file
- do not overlap in any way, meaning when you're loading a file you should never
- exceed the size of the file in memory
+  impossible to use malicious files to DOS attack a user. Currently, there's a limit
+  on the size of the header of 100MB to prevent parsing extremely large JSON.
+  Also when reading the file, there's a guarantee that addresses in the file
+  do not overlap in any way, meaning when you're loading a file you should never
+  exceed the size of the file in memory
 - Faster load: PyTorch seems to be the fastest file to load out in the major
-ML formats. However, it does seem to have an extra copy on CPU, which we
-can bypass in this lib by using `torch.UntypedStorage.from_file`.
-Currently, CPU loading times are extremely fast with this lib compared to pickle.
-GPU loading times are as fast or faster than PyTorch equivalent.
-Loading first on CPU with memmapping with torch, and then moving all tensors to GPU seems
-to be faster too somehow (similar behavior in torch pickle)
+  ML formats. However, it does seem to have an extra copy on CPU, which we
+  can bypass in this lib by using `torch.UntypedStorage.from_file`.
+  Currently, CPU loading times are extremely fast with this lib compared to pickle.
+  GPU loading times are as fast or faster than PyTorch equivalent.
+  Loading first on CPU with memmapping with torch, and then moving all tensors to GPU seems
+  to be faster too somehow (similar behavior in torch pickle)
 - Lazy loading: in distributed (multi-node or multi-gpu) settings, it's nice to be able to
-load only part of the tensors on the various models. For
-[BLOOM](https://huggingface.co/bigscience/bloom) using this format enabled
-to load the model on 8 GPUs from 10mn with regular PyTorch weights down to 45s.
-This really speeds up feedbacks loops when developing on the model. For instance
-you don't have to have separate copies of the weights when changing the distribution
-strategy (for instance Pipeline Parallelism vs Tensor Parallelism).
+  load only part of the tensors on the various models. For
+  [BLOOM](https://huggingface.co/bigscience/bloom) using this format enabled
+  to load the model on 8 GPUs from 10mn with regular PyTorch weights down to 45s.
+  This really speeds up feedbacks loops when developing on the model. For instance
+  you don't have to have separate copies of the weights when changing the distribution
+  strategy (for instance Pipeline Parallelism vs Tensor Parallelism).
 
 License: Apache-2.0
diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index 07e6ab93..71a9c839 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -9,7 +9,7 @@ name = "safetensors_rust"
 crate-type = ["cdylib"]
 
 [dependencies]
-pyo3 = { version = "0.21.1", features = ["extension-module"] }
+pyo3 = { version = "0.22" }
 memmap2 = "0.9"
 serde_json = "1.0"
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index d0e09c61..e0a52524 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -2,19 +2,21 @@
 //! Dummy doc
 use memmap2::{Mmap, MmapOptions};
 use pyo3::exceptions::{PyException, PyFileNotFoundError};
-use pyo3::sync::GILOnceCell;
 use pyo3::prelude::*;
+use pyo3::sync::GILOnceCell;
 use pyo3::types::IntoPyDict;
 use pyo3::types::PySlice;
 use pyo3::types::{PyByteArray, PyBytes, PyDict, PyList};
+use pyo3::Bound as PyBound;
 use pyo3::{intern, PyErr};
 use safetensors::slice::TensorIndexer;
 use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorInfo, TensorView};
+use safetensors::View;
+use std::borrow::Cow;
 use std::collections::HashMap;
 use std::fs::File;
 use std::iter::FromIterator;
 use std::ops::Bound;
-use pyo3::{Bound as PyBound};
 use std::path::PathBuf;
 use std::sync::Arc;
@@ -24,57 +26,76 @@ static TENSORFLOW_MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
 static FLAX_MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
 static MLX_MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
 
-fn prepare(tensor_dict: HashMap<String, &PyDict>) -> PyResult<HashMap<String, TensorView<'_>>> {
+struct PyView<'a> {
+    shape: Vec<usize>,
+    dtype: Dtype,
+    data: PyBound<'a, PyBytes>,
+    data_len: usize,
+}
+
+impl<'a> View for &PyView<'a> {
+    fn data(&self) -> std::borrow::Cow<[u8]> {
+        Cow::Borrowed(self.data.as_bytes())
+    }
+    fn shape(&self) -> &[usize] {
+        &self.shape
+    }
+    fn dtype(&self) -> Dtype {
+        self.dtype
+    }
+    fn data_len(&self) -> usize {
+        self.data_len
+    }
+}
+
+fn prepare(tensor_dict: HashMap<String, PyBound<PyDict>>) -> PyResult<HashMap<String, PyView>> {
     let mut tensors = HashMap::with_capacity(tensor_dict.len());
-    for (tensor_name, tensor_desc) in tensor_dict {
-        let mut shape: Option<Vec<usize>> = None;
-        let mut dtype: Option<Dtype> = None;
-        let mut data: Option<&[u8]> = None;
-        for (key, value) in tensor_desc {
-            let key: &str = key.extract()?;
-            match key {
-                "shape" => shape = value.extract()?,
-                "dtype" => {
-                    let value: &str = value.extract()?;
-                    dtype = match value {
-                        "bool" => Some(Dtype::BOOL),
-                        "int8" => Some(Dtype::I8),
-                        "uint8" => Some(Dtype::U8),
-                        "int16" => Some(Dtype::I16),
-                        "uint16" => Some(Dtype::U16),
-                        "int32" => Some(Dtype::I32),
-                        "uint32" => Some(Dtype::U32),
-                        "int64" => Some(Dtype::I64),
-                        "uint64" => Some(Dtype::U64),
-                        "float16" => Some(Dtype::F16),
-                        "float32" => Some(Dtype::F32),
-                        "float64" => Some(Dtype::F64),
-                        "bfloat16" => Some(Dtype::BF16),
-                        "float8_e4m3fn" => Some(Dtype::F8_E4M3),
-                        "float8_e5m2" => Some(Dtype::F8_E5M2),
-                        dtype_str => {
-                            return Err(SafetensorError::new_err(format!(
-                                "dtype {dtype_str} is not covered",
-                            )));
-                        }
-                    }
-                }
-                "data" => data = Some(value.extract::<&[u8]>()?),
-                _ => println!("Ignored unknown kwarg option {key}"),
-            };
-        }
-        let shape = shape.ok_or_else(|| {
-            SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}"))
+    for (tensor_name, tensor_desc) in &tensor_dict {
+        let shape: Vec<usize> = tensor_desc
+            .get_item("shape")?
+            .ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}")))?
+            .extract()?;
+        let pydata: PyBound<PyAny> = tensor_desc.get_item("data")?.ok_or_else(|| {
+            SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
         })?;
-        let dtype = dtype.ok_or_else(|| {
+        // Make sure it's extractable first.
+        let data: &[u8] = pydata.extract()?;
+        let data_len = data.len();
+        let data: PyBound<PyBytes> = pydata.extract()?;
+        let pydtype = tensor_desc.get_item("dtype")?.ok_or_else(|| {
             SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}"))
         })?;
-        let data = data.ok_or_else(|| {
-            SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
-        })?;
-        let tensor = TensorView::new(dtype, shape, data)
-            .map_err(|e| SafetensorError::new_err(format!("Error preparing tensor view: {e:?}")))?;
-        tensors.insert(tensor_name, tensor);
+        let dtype: &str = pydtype.extract()?;
+        let dtype = match dtype {
+            "bool" => Dtype::BOOL,
+            "int8" => Dtype::I8,
+            "uint8" => Dtype::U8,
+            "int16" => Dtype::I16,
+            "uint16" => Dtype::U16,
+            "int32" => Dtype::I32,
+            "uint32" => Dtype::U32,
+            "int64" => Dtype::I64,
+            "uint64" => Dtype::U64,
+            "float16" => Dtype::F16,
+            "float32" => Dtype::F32,
+            "float64" => Dtype::F64,
+            "bfloat16" => Dtype::BF16,
+            "float8_e4m3fn" => Dtype::F8_E4M3,
+            "float8_e5m2" => Dtype::F8_E5M2,
+            dtype_str => {
+                return Err(SafetensorError::new_err(format!(
+                    "dtype {dtype_str} is not covered",
+                )));
+            }
+        };
+
+        let tensor = PyView {
+            shape,
+            dtype,
+            data,
+            data_len,
+        };
+        tensors.insert(tensor_name.to_string(), tensor);
     }
     Ok(tensors)
 }
@@ -92,10 +113,10 @@ fn prepare(tensor_dict: HashMap<String, &PyDict>) -> PyResult<HashMap<String, T
 fn serialize<'b>(
     py: Python<'b>,
-    tensor_dict: HashMap<String, &PyDict>,
+    tensor_dict: HashMap<String, PyBound<PyDict>>,
     metadata: Option<HashMap<String, String>>,
 ) -> PyResult<PyBound<'b, PyBytes>> {
     let tensors = prepare(tensor_dict)?;
@@ -121,9 +142,9 @@
 ///     (`bytes`):
 ///         The serialized content.
 #[pyfunction]
-#[pyo3(text_signature = "(tensor_dict, filename, metadata=None)")]
+#[pyo3(signature = (tensor_dict, filename, metadata=None))]
 fn serialize_file(
-    tensor_dict: HashMap<String, &PyDict>,
+    tensor_dict: HashMap<String, PyBound<PyDict>>,
     filename: PathBuf,
     metadata: Option<HashMap<String, String>>,
 ) -> PyResult<()> {
@@ -144,7 +165,7 @@
 /// The deserialized content is like:
 ///     [("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)]
 #[pyfunction]
-#[pyo3(text_signature = "(bytes)")]
+#[pyo3(signature = (bytes))]
 fn deserialize(py: Python, bytes: &[u8]) -> PyResult<Vec<(String, HashMap<String, PyObject>)>> {
     let safetensor = SafeTensors::deserialize(bytes)
         .map_err(|e| SafetensorError::new_err(format!("Error while deserializing: {e:?}")))?;
@@ -215,7 +236,7 @@ enum Framework {
 }
 
 impl<'source> FromPyObject<'source> for Framework {
-    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+    fn extract_bound(ob: &PyBound<'source, PyAny>) -> PyResult<Self> {
         let name: String = ob.extract()?;
         match &name[..] {
             "pt" => Ok(Framework::Pytorch),
@@ -248,7 +269,7 @@ enum Device {
 }
 
 impl<'source> FromPyObject<'source> for Device {
-    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+    fn extract_bound(ob: &PyBound<'source, PyAny>) -> PyResult<Self> {
         if let Ok(name) = ob.extract::<String>() {
             match &name[..] {
                 "cpu" => Ok(Device::Cpu),
@@ -437,7 +458,8 @@ impl Open {
                 // .getattr(intern!(py, "from_file"))?
.call_method("from_file", (py_filename,), Some(&kwargs))?; - let untyped: PyBound<'_, PyAny> = match storage.getattr(intern!(py, "untyped")) { + let untyped: PyBound<'_, PyAny> = match storage.getattr(intern!(py, "untyped")) + { Ok(untyped) => untyped, Err(_) => storage.getattr(intern!(py, "_untyped"))?, }; @@ -515,7 +537,8 @@ impl Open { let data = &mmap[info.data_offsets.0 + self.offset..info.data_offsets.1 + self.offset]; - let array: PyObject = Python::with_gil(|py| PyByteArray::new_bound(py, data).into_py(py)); + let array: PyObject = + Python::with_gil(|py| PyByteArray::new_bound(py, data).into_py(py)); create_tensor( &self.framework, @@ -586,8 +609,7 @@ impl Open { if self.device != Device::Cpu { let device: PyObject = self.device.clone().into_py(py); let kwargs = PyDict::new_bound(py); - tensor = tensor - .call_method("to", (device,), Some(&kwargs))?; + tensor = tensor.call_method("to", (device,), Some(&kwargs))?; } Ok(tensor.into_py(py)) // torch.asarray(storage[start + n : stop + n], dtype=torch.uint8).view(dtype=dtype).reshape(shape) @@ -661,7 +683,7 @@ impl safe_open { #[pymethods] impl safe_open { #[new] - #[pyo3(text_signature = "(self, filename, framework, device=\"cpu\")")] + #[pyo3(signature = (filename, framework, device=Some(Device::Cpu)))] fn new(filename: PathBuf, framework: Framework, device: Option) -> PyResult { let inner = Some(Open::new(filename, framework, device)?); Ok(Self { inner }) @@ -748,7 +770,7 @@ struct PySafeSlice { #[derive(FromPyObject)] enum SliceIndex<'a> { - Slice(&'a PySlice), + Slice(PyBound<'a, PySlice>), Index(i32), } @@ -923,8 +945,7 @@ impl PySafeSlice { if self.device != Device::Cpu { let device: PyObject = self.device.clone().into_py(py); let kwargs = PyDict::new_bound(py); - tensor = tensor - .call_method("to", (device,), Some(&kwargs))?; + tensor = tensor.call_method("to", (device,), Some(&kwargs))?; } Ok(tensor.into_py(py)) }), @@ -974,12 +995,12 @@ fn create_tensor<'a>( let dtype: PyObject = get_pydtype(module, dtype, is_numpy)?; let count: usize = shape.iter().product(); let shape = shape.to_vec(); - let shape: PyObject = shape.into_py(py); let tensor = if count == 0 { // Torch==1.10 does not allow frombuffer on empty buffers so we create // the tensor manually. // let zeros = module.getattr(intern!(py, "zeros"))?; - let args = (shape.clone(),); + let shape: PyObject = shape.clone().into_py(py); + let args = (shape,); let kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py); module.call_method("zeros", args, Some(&kwargs))? } else { @@ -1029,8 +1050,7 @@ fn create_tensor<'a>( if device != &Device::Cpu { let device: PyObject = device.clone().into_py(py); let kwargs = PyDict::new_bound(py); - tensor = tensor - .call_method("to", (device,), Some(&kwargs))?; + tensor = tensor.call_method("to", (device,), Some(&kwargs))?; } tensor } @@ -1067,7 +1087,9 @@ fn get_pydtype(module: &PyBound<'_, PyModule>, dtype: Dtype, is_numpy: bool) -> Dtype::I8 => module.getattr(intern!(py, "int8"))?.into(), Dtype::BOOL => { if is_numpy { - py.import_bound("builtins")?.getattr(intern!(py, "bool"))?.into() + py.import_bound("builtins")? + .getattr(intern!(py, "bool"))? 
+                    .into()
             } else {
                 module.getattr(intern!(py, "bool"))?.into()
             }
@@ -1098,7 +1120,10 @@ fn _safetensors_rust(m: &PyBound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(serialize_file, m)?)?;
     m.add_function(wrap_pyfunction!(deserialize, m)?)?;
     m.add_class::<safe_open>()?;
-    m.add("SafetensorError", m.py().get_type_bound::<SafetensorError>())?;
+    m.add(
+        "SafetensorError",
+        m.py().get_type_bound::<SafetensorError>(),
+    )?;
     m.add("__version__", env!("CARGO_PKG_VERSION"))?;
     Ok(())
 }
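
The key change in `bindings/python/src/lib.rs` above is that `prepare` no longer builds `TensorView`s over borrowed `&[u8]` slices; it now returns `PyView` values that implement the `safetensors::View` trait, so serialization reads each tensor's dtype, shape, and bytes through that trait while the bytes stay owned by the Python `bytes` objects. Below is a minimal standalone sketch of the same pattern outside of pyo3; the `SliceView` type is hypothetical, and the exact `serialize` signature is assumed from the 0.4-series `safetensors` crate rather than taken from this diff.

```rust
use std::borrow::Cow;
use std::collections::HashMap;

use safetensors::tensor::{serialize, Dtype};
use safetensors::View;

/// Hypothetical stand-in for `PyView`: a tensor whose bytes live in a plain slice.
struct SliceView<'a> {
    dtype: Dtype,
    shape: Vec<usize>,
    data: &'a [u8],
}

// Implemented on `&SliceView`, mirroring `impl<'a> View for &PyView<'a>` in the diff,
// so the items yielded by iterating a `&HashMap<String, SliceView>` can act as views.
impl<'a> View for &SliceView<'a> {
    fn dtype(&self) -> Dtype {
        self.dtype
    }
    fn shape(&self) -> &[usize] {
        &self.shape
    }
    fn data(&self) -> Cow<[u8]> {
        Cow::Borrowed(self.data)
    }
    fn data_len(&self) -> usize {
        self.data.len()
    }
}

fn main() {
    // One 2x2 f32 tensor: four little-endian zeros, i.e. 16 bytes of data.
    let raw = vec![0u8; 16];
    let view = SliceView {
        dtype: Dtype::F32,
        shape: vec![2, 2],
        data: &raw,
    };

    let mut tensors = HashMap::new();
    tensors.insert("weight".to_string(), view);

    // Assumed 0.4-series signature: an iterator of (name, impl View) pairs
    // plus an optional string-to-string metadata map.
    let metadata: Option<HashMap<String, String>> = None;
    let bytes = serialize(&tensors, &metadata).expect("serialization failed");
    println!("serialized {} bytes", bytes.len());
}
```

As in the diff, `View` is implemented for the reference type rather than the owning type, which lets a borrowed map of owned views be handed to `serialize` without moving or copying the underlying buffers.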