Skip to content

Commit

Permalink
Fixing empty serialization (no tensor) with some metadata. (#472)
Browse files Browse the repository at this point in the history
* Fixing empty serialization (no tensor) with some metadata.

* Fixing test value

* Add audit component.

* Install cargo audit.
  • Loading branch information
Narsil authored Apr 24, 2024
1 parent ebf453b commit 079781f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ jobs:

- uses: Swatinem/rust-cache@v2

- name: Install cargo-audit
run: cargo install cargo-audit

- name: Install cargo-llvm-cov for Ubuntu
if: matrix.os == 'ubuntu-latest'
run: cargo install cargo-llvm-cov
Expand Down
35 changes: 34 additions & 1 deletion safetensors/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,12 @@ impl Serialize for Metadata {
}

let tensors: Vec<_> = names.iter().zip(self.tensors.iter()).collect();
let mut map = serializer.serialize_map(Some(tensors.len()))?;
let length = if let Some(metadata) = &self.metadata {
metadata.len()
} else {
0
};
let mut map = serializer.serialize_map(Some(tensors.len() + length))?;
if let Some(metadata) = &self.metadata {
map.serialize_entry("__metadata__", metadata)?;
}
Expand Down Expand Up @@ -833,6 +838,34 @@ mod tests {
let _parsed = SafeTensors::deserialize(&out).unwrap();
}

#[test]
fn test_empty() {
let tensors: HashMap<String, TensorView> = HashMap::new();

let out = serialize(&tensors, &None).unwrap();
assert_eq!(
out,
[8, 0, 0, 0, 0, 0, 0, 0, 123, 125, 32, 32, 32, 32, 32, 32]
);
let _parsed = SafeTensors::deserialize(&out).unwrap();

let metadata: Option<HashMap<String, String>> = Some(
[("framework".to_string(), "pt".to_string())]
.into_iter()
.collect(),
);
let out = serialize(&tensors, &metadata).unwrap();
assert_eq!(
out,
[
40, 0, 0, 0, 0, 0, 0, 0, 123, 34, 95, 95, 109, 101, 116, 97, 100, 97, 116, 97, 95,
95, 34, 58, 123, 34, 102, 114, 97, 109, 101, 119, 111, 114, 107, 34, 58, 34, 112,
116, 34, 125, 125, 32, 32, 32, 32, 32
]
);
let _parsed = SafeTensors::deserialize(&out).unwrap();
}

#[test]
fn test_serialization_forced_alignement() {
let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
Expand Down

0 comments on commit 079781f

Please sign in to comment.