Safetensors support. #381

Merged (6 commits) on Mar 8, 2023
2 changes: 1 addition & 1 deletion .github/workflows/cargo-test.yml
@@ -30,4 +30,4 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: test
-args: --features test-f64
+args: --features test-f64,safetensors
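For reference, this makes CI exercise the new code path as well; the equivalent local check is `cargo test --features test-f64,safetensors` run from the repository root.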
5 changes: 4 additions & 1 deletion Cargo.toml
@@ -33,6 +33,8 @@ cblas-sys = { version = "0.1.4", default-features = false, optional = true }
libc = { version = "0.2", default-features = false, optional = true }
cudarc = { version = "0.8.0", default-features = false, optional = true }
num-traits = { version = "0.2.15", default-features = false }
+safetensors = { version = "0.3", default-features = false, optional = true }
+memmap2 = { version = "0.5", default-features = false, optional = true }

[dev-dependencies]
tempfile = "3.3.0"
@@ -50,6 +52,7 @@ threaded-cpu = ["std", "matrixmultiply/threading"]
fast-alloc = ["std"]
nightly = []
numpy = ["dep:zip", "std"]
+safetensors = ["dep:safetensors", "std", "dep:memmap2"]
cblas = ["dep:cblas-sys", "dep:libc"]
intel-mkl = ["cblas"]
cuda = ["dep:cudarc", "dep:glob"]
@@ -71,4 +74,4 @@ harness = false

[[bench]]
name = "softmax"
-harness = false
\ No newline at end of file
+harness = false
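Downstream crates opt in through the new cargo feature. A minimal sketch of a consumer's `Cargo.toml`, assuming a published dfdx release that includes this PR (the version requirement is a placeholder):

```toml
[dependencies]
# "safetensors" is the dfdx feature defined above; enabling it pulls in the
# optional `safetensors` and `memmap2` crates and requires `std`.
dfdx = { version = "*", features = ["safetensors"] }
```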
44 changes: 44 additions & 0 deletions examples/safetensors-save-load.rs
@@ -0,0 +1,44 @@
//! Demonstrates how to save and load arrays with safetensors

#[cfg(feature = "safetensors")]
fn main() {
use ::safetensors::SafeTensors;
use dfdx::{
prelude::*,
tensor::{AsArray, Cpu},
};
use memmap2::MmapOptions;
let dev: Cpu = Default::default();

type Model = (Linear<5, 10>, Linear<10, 5>);
let model = dev.build_module::<Model, f32>();
model
.save_safetensors("model.safetensors")
.expect("Failed to save model");

let mut model2 = dev.build_module::<Model, f32>();
model2
.load_safetensors("model.safetensors")
.expect("Failed to load model");

assert_eq!(model.0.weight.array(), model2.0.weight.array());

// ADVANCED USAGE to load pre-existing models

// wget -O gpt2.safetensors https://huggingface.co/gpt2/resolve/main/model.safetensors

let mut gpt2 = dev.build_module::<Linear<768, 50257>, f32>(); // GPT-2's hidden size is 768
let filename = "gpt2.safetensors";
let f = std::fs::File::open(filename).expect("Couldn't read file; have you downloaded gpt2? `wget -O gpt2.safetensors https://huggingface.co/gpt2/resolve/main/model.safetensors`");
let buffer = unsafe { MmapOptions::new().map(&f).expect("Could not mmap") };
let tensors = SafeTensors::deserialize(&buffer).expect("Couldn't read safetensors file");

gpt2.weight
.load_safetensor(&tensors, "wte.weight")
.expect("Could not load tensor");
[Owner comment on lines +36 to +38]: Awesome example

}

#[cfg(not(feature = "safetensors"))]
fn main() {
panic!("Use the 'safetensors' feature to run this example");
}
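Before hard-coding checkpoint keys like `wte.weight` into `load_safetensor`, it can help to list what a file actually contains. A minimal sketch (not part of this PR), assuming the `names()`/`tensor()` accessors of the `safetensors` 0.3 crate and the same `memmap2` mapping used in the example above:

```rust
use memmap2::MmapOptions;
use safetensors::SafeTensors;

/// Hypothetical helper: print every tensor in a `.safetensors` file
/// together with its dtype and shape.
fn list_tensors(path: &str) {
    let f = std::fs::File::open(path).expect("Couldn't read file");
    // Safety: the mapped file must not be modified while the map is alive.
    let buffer = unsafe { MmapOptions::new().map(&f).expect("Could not mmap") };
    let tensors = SafeTensors::deserialize(&buffer).expect("Couldn't read safetensors file");
    for name in tensors.names() {
        let view = tensors.tensor(name).expect("name came from names()");
        println!("{name}: {:?} {:?}", view.dtype(), view.shape());
    }
}
```

Run against `gpt2.safetensors`, this should list the `wte.weight` entry (with shape `[50257, 768]`) that the advanced section above loads.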
21 changes: 21 additions & 0 deletions src/nn/mod.rs
@@ -103,6 +103,23 @@
//! state_dict = {k: torch.from_numpy(v) for k, v in np.load("dfdx-model.npz").items()}
//! mlp.load_state_dict(state_dict)
//! ```
+//!
+//! The `safetensors` feature enables the same workflow with
+//! [safetensors](https://github.com/huggingface/safetensors).
+//! Call [SaveToSafetensors::save_safetensors()] and [LoadFromSafetensors::load_safetensors()];
+//! all modules provided here implement these traits, including tuples. They save to / load
+//! from `.safetensors` files, which use a flat layout with a JSON header, allowing for very fast loads (via memory mapping).
+//!
+//! The format is fairly portable. For example, you can export a model with
+//! [transformers](https://github.com/huggingface/transformers):
+//!
+//! ```python
+//! from transformers import pipeline
+//!
+//! pipe = pipeline(model="gpt2")
+//! pipe.save_pretrained("my_local", safe_serialization=True)
+//! # This creates the `my_local/model.safetensors` file, which can now be used.
+//! ```

mod num_params;
mod reset_params;
@@ -128,13 +145,17 @@ mod pool2d;
mod pool_global;
mod repeated;
mod residual;
+#[cfg(feature = "safetensors")]
+mod safetensors;
mod split_into;
mod transformer;
mod unbiased_linear;
mod zero_grads;

pub use module::*;

+#[cfg(feature = "safetensors")]
+pub use crate::nn::safetensors::{LoadFromSafetensors, SaveToSafetensors};
pub use ema::ModelEMA;
#[cfg(feature = "numpy")]
pub use npz::{LoadFromNpz, SaveToNpz};
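To close the loop on the `transformers` snippet in the module docs above, a hedged sketch of the Rust side. It assumes the `safetensors` feature is enabled, that the Python export produced `my_local/model.safetensors`, and that the checkpoint stores a `[50257, 768]` token embedding under the key `wte.weight` (depending on the exported architecture the key may instead be prefixed, e.g. `transformer.wte.weight`):

```rust
use ::safetensors::SafeTensors;
use dfdx::{prelude::*, tensor::Cpu};
use memmap2::MmapOptions;

fn main() {
    let dev: Cpu = Default::default();

    let f = std::fs::File::open("my_local/model.safetensors")
        .expect("Couldn't read file; run the Python export first");
    let buffer = unsafe { MmapOptions::new().map(&f).expect("Could not mmap") };
    let tensors = SafeTensors::deserialize(&buffer).expect("Couldn't read safetensors file");

    // Load one named tensor into one module parameter, exactly as in the
    // advanced part of examples/safetensors-save-load.rs.
    let mut wte = dev.build_module::<Linear<768, 50257>, f32>();
    wte.weight
        .load_safetensor(&tensors, "wte.weight")
        .expect("Could not load tensor");
}
```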