Rebased.
Narsil committed Feb 24, 2023
1 parent 43e2d80 commit c76e417
Showing 6 changed files with 419 additions and 101 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
@@ -33,15 +33,17 @@ cblas-sys = { version = "0.1.4", default-features = false, optional = true }
libc = { version = "0.2", default-features = false, optional = true }
cudarc = { version = "0.7.2", default-features = false, optional = true }
num-traits = { version = "0.2.15", default-features = false }
safetensors = { version = "0.2", default-features = false, optional = true }
# safetensors = { version = "0.2", default-features = false, optional = true }
safetensors = { git = "https://github.com/huggingface/safetensors", default-features = false, optional = true }
memmap2 = { version = "0.5", optional = true }

[features]
default = ["std", "numpy", "fast_alloc"]
std = ["no-std-compat/std", "rand/std", "rand_distr/std", "cudarc?/std", "matrixmultiply/threading"]
fast_alloc = ["std"]
nightly = []
numpy = ["dep:zip", "std"]
safetensors = ["dep:safetensors", "std"]
safetensors = ["dep:safetensors", "std", "dep:memmap2"]
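# memmap2 is pulled in so safetensors files can be memory-mapped on load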
cblas = ["dep:cblas-sys", "dep:libc"]
intel-mkl = ["cblas"]
cuda = ["dep:cudarc"]
45 changes: 33 additions & 12 deletions examples/safetensors-save-load.rs
@@ -3,25 +3,46 @@
#[cfg(feature = "safetensors")]
fn main() {
use dfdx::{
nn::{DeviceBuildExt, LoadFromSafeTensors, SaveToSafeTensors},
prelude::Linear,
shapes::{Rank0, Rank1, Rank2},
tensor::safetensors::Writer,
tensor::{AsArray, Cpu, Tensor, TensorFromArray, ZerosTensor},
tensor::safetensors::SafeWriter,
tensor::{AsArray, Tensor, TensorFrom, ZerosTensor},
};
use safetensors::tensor::SafeTensors;
let dev: Cpu = Default::default();
#[cfg(not(feature = "cuda"))]
type Device = dfdx::tensor::Cpu;

#[cfg(feature = "cuda")]
type Device = dfdx::tensor::Cuda;

let dev: Device = Default::default();

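// Save a whole module to a `.safetensors` file, then load it back into a
// freshly built module and check that the parameters round-trip exactly.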
type Model = Linear<4, 2>;
let m = dev.build_module::<Model, f32>();

m.save_safetensors("linear.safetensors")
.expect("Failed to write");

let mut m2 = dev.build_module::<Model, f32>();
assert_ne!(m.weight.array(), m2.weight.array());
assert_ne!(m.bias.array(), m2.bias.array());
m2.load_safetensors("linear.safetensors")
.expect("Failed to load");
assert_eq!(m.weight.array(), m2.weight.array());
assert_eq!(m.bias.array(), m2.bias.array());

let a = dev.tensor(1.234f32);
let b = dev.tensor([1.0f32, 2.0, 3.0]);
let c = dev.tensor([[1.0f32, 2.0, 3.0], [-1.0, -2.0, -3.0]]);

let path = std::path::Path::new("out.safetensors");

Writer::new()
.add("a".to_string(), a)
.add("b".to_string(), b)
.add("c".to_string(), c)
.save(path)
.unwrap();
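// Individual tensors can also be written by hand: register each one under a
// name with `SafeWriter`, then save them all into a single file.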
let mut w = SafeWriter::new();
w.add("a".to_string(), &a);
w.add("b".to_string(), &b);
w.add("c".to_string(), &c);
w.save_safetensors(path).unwrap();

let mut a: Tensor<Rank0, f32, _> = dev.zeros();
let mut b: Tensor<Rank1<3>, f32, _> = dev.zeros();
@@ -30,9 +51,9 @@ fn main() {
let filename = "out.safetensors";
let buffer = std::fs::read(filename).expect("Couldn't read file");
let tensors = SafeTensors::deserialize(&buffer).expect("Couldn't read safetensors file");
a.load(&tensors, "a").expect("Loading a failed");
b.load(&tensors, "b").expect("Loading b failed");
c.load(&tensors, "c").expect("Loading c failed");
a.load_safetensors(&tensors, "a").expect("Loading a failed");
b.load_safetensors(&tensors, "b").expect("Loading b failed");
c.load_safetensors(&tensors, "c").expect("Loading c failed");

assert_eq!(a.array(), 1.234);
assert_eq!(b.array(), [1.0, 2.0, 3.0]);
4 changes: 4 additions & 0 deletions src/nn/mod.rs
@@ -126,11 +126,15 @@ mod pool2d;
mod pool_global;
mod repeated;
mod residual;
#[cfg(feature = "safetensors")]
mod safetensors;
mod split_into;
mod transformer;

pub use module::*;

#[cfg(feature = "safetensors")]
pub use crate::nn::safetensors::{LoadFromSafeTensors, SaveToSafeTensors};
#[cfg(feature = "numpy")]
pub use npz::{LoadFromNpz, SaveToNpz};
pub use num_params::NumParams;
292 changes: 292 additions & 0 deletions src/nn/safetensors.rs
@@ -0,0 +1,292 @@
use crate::{
shapes::{Dtype, Shape},
tensor::{
safetensors::{SafeDtype, SafeWriter},
CopySlice, Tensor,
},
};
use memmap2::MmapOptions;
use safetensors::tensor::{SafeTensorError, SafeTensors};

use super::tensor_collection::*;

use std::{path::Path, string::String};

/// Something that can be saved to a `.safetensors` file.
///
/// All [super::Module]s in nn implement SaveToSafeTensors.
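///
/// Example (a minimal sketch, mirroring `examples/safetensors-save-load.rs`):
/// ```ignore
/// # use dfdx::prelude::*;
/// let dev: Cpu = Default::default();
/// let model = dev.build_module::<Linear<4, 2>, f32>();
/// model.save_safetensors("linear.safetensors")?;
/// ```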
pub trait SaveToSafeTensors<E: Dtype + SafeDtype, D: CopySlice<E>>: TensorCollection<E, D> {
fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<(), SafeTensorError> {
let mut writer = SafeWriter::new();
self.write_safetensors(&mut writer)?;
writer.save_safetensors(path.as_ref())?;
Ok(())
}

fn write_safetensors(&self, w: &mut SafeWriter) -> Result<(), SafeTensorError> {
Self::iter_tensors(&mut RecursiveWalker {
m: self,
f: w,
path: &mut std::vec::Vec::new(),
})
}
}
impl<E: Dtype + SafeDtype, D: CopySlice<E>, T: TensorCollection<E, D>> SaveToSafeTensors<E, D>
for T
{
}

/// Something that can be loaded from a `.safetensors` file.
///
/// All [super::Module]s in nn implement LoadFromSafeTensors.
pub trait LoadFromSafeTensors<E: Dtype + SafeDtype, D: CopySlice<E>>:
TensorCollection<E, D>
{
/// Loads data from a `.safetensors` file at the specified `path`.
///
/// Example:
/// ```ignore
/// # use dfdx::prelude::*;
/// let mut model: (Linear<5, 10>, Linear<10, 5>) = Default::default();
/// model.load_safetensors("tst.safetensors")?;
/// ```
fn load_safetensors<P: AsRef<Path>>(&mut self, path: P) -> Result<(), SafeTensorError> {
let file = std::fs::File::open(path)?;
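// Memory-map the checkpoint rather than reading it into a heap buffer; `map`
// is unsafe because the mapping is only valid as long as no other process
// modifies or truncates the underlying file.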
let buffer = unsafe { MmapOptions::new().map(&file)? };
let mut tensors = SafeTensors::deserialize(&buffer)?;
self.read_safetensors(&mut tensors)?;
Ok(())
}

fn read_safetensors<'data>(
&mut self,
tensors: &mut SafeTensors<'data>,
) -> Result<(), SafeTensorError> {
Self::iter_tensors(&mut RecursiveWalker {
m: self,
f: tensors,
path: &mut std::vec::Vec::new(),
})
}
}
impl<E: Dtype + SafeDtype, D: CopySlice<E>, T: TensorCollection<E, D>> LoadFromSafeTensors<E, D>
for T
{
}

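// Saving path: `write_safetensors` walks the module with a `RecursiveWalker`,
// and this visitor registers every tensor in the `SafeWriter` under the
// `full_path` built up by the walker.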
impl<E: Dtype + SafeDtype, D: CopySlice<E>> TensorVisitor<E, D> for SafeWriter {
type Viewer = ViewTensorRef;
type Err = SafeTensorError;

fn visit<S: Shape>(
&mut self,
full_path: String,
_: TensorOptions<S, E, D>,
t: &Tensor<S, E, D>,
) -> Result<(), Self::Err> {
self.add(full_path, t);
Ok(())
}
}

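// Loading path is the mirror image: each visited tensor copies its data out of
// the deserialized `SafeTensors` under the same `full_path` via
// `Tensor::load_safetensors`.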
impl<'data, E: Dtype + SafeDtype, D: CopySlice<E>> TensorVisitor<E, D> for SafeTensors<'data> {
type Viewer = ViewTensorMut;
type Err = SafeTensorError;

fn visit<S: Shape>(
&mut self,
full_path: String,
_: TensorOptions<S, E, D>,
t: &mut Tensor<S, E, D>,
) -> Result<(), Self::Err> {
t.load_safetensors(self, &full_path)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
nn::{builders::*, *},
shapes::*,
tensor::{safetensors::SafeDtype, AsArray, SampleTensor, Tensor},
tensor_ops::Device,
tests::{TestDevice, TestDtype},
};
use rand_distr::{Distribution, Standard, StandardNormal};
use tempfile::NamedTempFile;

fn test_save_load<S: ConstShape, E: Dtype + SafeDtype, D: Device<E>, M: BuildOnDevice<D, E>>(
dev: &D,
) where
M::Built: Module<Tensor<S, E, D>> + SaveToSafeTensors<E, D> + LoadFromSafeTensors<E, D>,
<M::Built as Module<Tensor<S, E, D>>>::Output: AsArray,
StandardNormal: Distribution<E>,
{
let x = dev.sample_normal();
let file = NamedTempFile::new().expect("failed to create tempfile");

let saved: M::Built = M::build_on_device(dev);
let mut loaded: M::Built = M::build_on_device(dev);

let y = saved.forward(x.clone());

assert_ne!(loaded.forward(x.clone()).array(), y.array());

saved.save_safetensors(file.path()).expect("");
loaded.load_safetensors(file.path()).expect("");

assert_eq!(loaded.forward(x).array(), y.array());
}

#[test]
fn test_batchnorm2d_save_load() {
let dev: TestDevice = Default::default();
type Model = BatchNorm2D<3>;

let x: Tensor<Rank3<3, 4, 5>, TestDtype, _> = dev.sample_normal();
let file = NamedTempFile::new().expect("failed to create tempfile");

let mut saved = Model::build_on_device(&dev);
let mut loaded = Model::build_on_device(&dev);

saved.running_mean.fill_with_distr(Standard);
saved.running_var.fill_with_distr(Standard);
saved.scale.fill_with_distr(Standard);
saved.bias.fill_with_distr(Standard);
let y = saved.forward(x.clone());

assert_ne!(loaded.forward(x.clone()).array(), y.array());

saved.save_safetensors(file.path()).expect("");
loaded.load_safetensors(file.path()).expect("");

assert_eq!(loaded.forward(x).array(), y.array());
}

#[cfg(feature = "nightly")]
#[test]
fn test_save_load_conv() {
type T = Conv2D<2, 4, 3>;
let dev: TestDevice = Default::default();
test_save_load::<Rank3<2, 8, 8>, TestDtype, TestDevice, T>(&dev);
}

#[test]
fn test_save_load_generalized_residual() {
let dev: TestDevice = Default::default();
type T = GeneralizedResidual<Linear<5, 5>, Linear<5, 5>>;
test_save_load::<Rank1<5>, TestDtype, TestDevice, T>(&dev);
test_save_load::<Rank1<5>, TestDtype, TestDevice, (T, T)>(&dev);
}

#[test]
fn test_save_load_linear() {
let dev: TestDevice = Default::default();
type T = Linear<5, 5>;
test_save_load::<Rank1<5>, TestDtype, TestDevice, T>(&dev);
test_save_load::<Rank1<5>, TestDtype, TestDevice, (T, T)>(&dev);
}

#[test]
fn test_save_load_tuple() {
let dev: TestDevice = Default::default();
type T = (
(Linear<1, 2>, ReLU, Linear<2, 3>),
(Dropout, Linear<3, 3>, Linear<3, 4>),
);
test_save_load::<Rank1<1>, TestDtype, TestDevice, T>(&dev);
}

#[test]
fn test_save_load_layer_norm() {
type M = LayerNorm1D<3>;
let dev: TestDevice = Default::default();
let x: Tensor<Rank1<3>, TestDtype, _> = dev.sample_normal();

let file = NamedTempFile::new().expect("failed to create tempfile");

let mut saved = M::build_on_device(&dev);
let mut loaded = M::build_on_device(&dev);

saved.gamma.fill_with_distr(Standard);
saved.beta.fill_with_distr(Standard);
let y = saved.forward(x.clone());

assert_ne!(loaded.forward(x.clone()).array(), y.array());

saved.save_safetensors(file.path()).expect("");
loaded.load_safetensors(file.path()).expect("");

assert_eq!(loaded.forward(x).array(), y.array());
}

#[test]
fn test_save_load_repeated() {
type T = Repeated<Linear<3, 3>, 4>;
let dev: TestDevice = Default::default();
test_save_load::<Rank1<3>, TestDtype, TestDevice, T>(&dev);
test_save_load::<Rank1<3>, TestDtype, TestDevice, (T, T)>(&dev);
}

#[test]
fn test_save_load_residual() {
type T = Residual<Linear<5, 5>>;
let dev: TestDevice = Default::default();
test_save_load::<Rank1<5>, TestDtype, TestDevice, T>(&dev);
test_save_load::<Rank1<5>, TestDtype, TestDevice, (T, T)>(&dev);
}

#[cfg(feature = "nightly")]
#[test]
fn test_save_load_mha() {
let dev: TestDevice = Default::default();
type Model = MultiHeadAttention<12, 4>;

let saved = Model::build_on_device(&dev);

let file = NamedTempFile::new().expect("failed to create tempfile");
saved.save_safetensors(file.path()).expect("");

let mut loaded = Model::build_on_device(&dev);

let q: Tensor<Rank3<2, 3, 12>, TestDtype, _> = dev.sample_normal();
let k: Tensor<Rank3<2, 4, 12>, TestDtype, _> = dev.sample_normal();
let v: Tensor<Rank3<2, 4, 12>, TestDtype, _> = dev.sample_normal();
let y1 = saved.forward((q.clone(), k.clone(), v.clone()));

let y2 = loaded.forward((q.clone(), k.clone(), v.clone()));
assert_ne!(y1.array(), y2.array());

loaded.load_safetensors(file.path()).expect("");

let y2 = loaded.forward((q.clone(), k.clone(), v.clone()));
assert_eq!(y1.array(), y2.array());
}

#[cfg(feature = "nightly")]
#[test]
fn test_save_load_transformer() {
let dev: TestDevice = Default::default();
type Model = Transformer<16, 4, 3, 4, 8>;

let mut saved = Model::build_on_device(&dev);

let file = NamedTempFile::new().expect("failed to create tempfile");
saved.save_safetensors(file.path()).expect("");

let mut loaded = Model::build_on_device(&dev);

let src: Tensor<Rank3<4, 12, 16>, TestDtype, _> = dev.sample_normal();
let tgt: Tensor<Rank3<4, 6, 16>, TestDtype, _> = dev.sample_normal();
let y1 = saved.forward_mut((src.clone(), tgt.clone()));

let y2 = loaded.forward_mut((src.clone(), tgt.clone()));
assert_ne!(y1.array(), y2.array());

loaded.load_safetensors(file.path()).expect("");

let y2 = loaded.forward_mut((src.clone(), tgt.clone()));
assert_eq!(y1.array(), y2.array());
}
}
