Merge pull request #108 from robertknight/pool-alloc
Add buffer pool/arena to enable re-use of temporary buffers during graph execution
robertknight authored Apr 23, 2024
2 parents 315fd11 + e49ab2c commit dfa490e
Showing 30 changed files with 813 additions and 277 deletions.
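
This PR threads a TensorPool through graph execution and through operator entry points such as Operator::run and rten::ops::concat, so temporary buffers released in one execution step can be re-used by later steps. As a rough, self-contained sketch of the idea only (this is not the crate's TensorPool implementation, whose API is mostly outside this diff), a buffer pool keeps released buffers on a free list, hands one back when a later allocation fits, and counts pool hits versus fresh allocations:

use std::cell::{Cell, RefCell};

// Rough sketch of a buffer pool (NOT rten's `TensorPool`): released buffers
// sit on a free list and are re-used by later allocations that fit.
struct SimplePool {
    free: RefCell<Vec<Vec<f32>>>,
    allocs: Cell<usize>,
    hits: Cell<usize>,
}

impl SimplePool {
    fn new() -> SimplePool {
        SimplePool {
            free: RefCell::new(Vec::new()),
            allocs: Cell::new(0),
            hits: Cell::new(0),
        }
    }

    // Get a buffer with room for `len` elements, preferring a released one.
    fn alloc(&self, len: usize) -> Vec<f32> {
        self.allocs.set(self.allocs.get() + 1);
        let mut free = self.free.borrow_mut();
        if let Some(pos) = free.iter().position(|buf| buf.capacity() >= len) {
            self.hits.set(self.hits.get() + 1);
            let mut buf = free.swap_remove(pos);
            buf.clear();
            buf
        } else {
            Vec::with_capacity(len)
        }
    }

    // Hand a finished buffer back for later re-use.
    fn add(&self, buf: Vec<f32>) {
        self.free.borrow_mut().push(buf);
    }

    fn alloc_count(&self) -> usize {
        self.allocs.get()
    }

    fn hit_count(&self) -> usize {
        self.hits.get()
    }
}

fn main() {
    let pool = SimplePool::new();
    let a = pool.alloc(1024); // fresh allocation
    pool.add(a);              // release it back to the pool
    let _b = pool.alloc(512); // re-uses the released buffer
    assert_eq!((pool.alloc_count(), pool.hit_count()), (2, 1));
}

The pool added in this PR works on whole tensors (Output::FloatTensor / Output::IntTensor) rather than raw Vecs, as the graph.rs changes below show.
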
5 changes: 3 additions & 2 deletions rten-examples/src/jina_similarity.rs
@@ -3,7 +3,7 @@ use std::error::Error;
use std::fs;

use rten::ops::concat;
use rten::{FloatOperators, Input, Model, NodeId, Operators};
use rten::{FloatOperators, Input, Model, NodeId, Operators, TensorPool};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};
use rten_text::tokenizers::{EncodeOptions, Tokenizer};
@@ -161,7 +161,8 @@ fn embed_sentence_batch(
view
})
.collect();
let mean_pooled: NdTensor<f32, 2> = concat(&mean_pooled_views, 0)?.try_into()?;
let pool = TensorPool::new();
let mean_pooled: NdTensor<f32, 2> = concat(&pool, &mean_pooled_views, 0)?.try_into()?;
Ok(mean_pooled)
}

3 changes: 2 additions & 1 deletion rten-tensor/src/lib.rs
@@ -62,7 +62,8 @@ pub use iterators::{
IterMut, Lanes, LanesMut, Offsets,
};
pub use layout::{
is_valid_permutation, DynLayout, Layout, MatrixLayout, MutLayout, NdLayout, OverlapPolicy,
is_valid_permutation, DynLayout, IntoLayout, Layout, MatrixLayout, MutLayout, NdLayout,
OverlapPolicy,
};
pub use slice_range::{to_slice_items, DynSliceItems, IntoSliceItems, SliceItem, SliceRange};

61 changes: 61 additions & 0 deletions rten-tensor/src/tensor.rs
@@ -675,6 +675,12 @@ impl<T, L: Clone + MutLayout> TensorBase<T, Vec<T>, L> {
}
}

/// Consume self and return the underlying data in whatever order the
/// elements are currently stored.
pub fn into_non_contiguous_data(self) -> Vec<T> {
self.data
}

/// Consume self and return a new contiguous tensor with the given shape.
///
/// This avoids copying the data if it is already contiguous.
@@ -861,6 +867,27 @@
element_type: PhantomData,
}
}

/// Initialize this tensor with data from another view.
///
/// This tensor and `other` must have the same shape.
pub fn init_from<S2: AsRef<[T]>>(
mut self,
other: &TensorBase<T, S2, L>,
) -> TensorBase<T, <S as AssumeInit>::Output, L>
where
T: Copy,
S: AsMut<[MaybeUninit<T>]>,
{
assert_eq!(self.shape(), other.shape(), "shape mismatch");
if let Some(data) = other.data() {
let data: &[MaybeUninit<T>] = unsafe { std::mem::transmute(data) };
self.data.as_mut().clone_from_slice(data);
} else {
copy_contiguous(other.as_dyn(), self.data.as_mut());
}
unsafe { self.assume_init() }
}
}

impl<'a, T, L: Clone + MutLayout> TensorBase<T, &'a [T], L> {
@@ -2173,10 +2200,44 @@ mod tests {
assert_eq!(tensor[[1, 1]], 9.);
}

#[test]
fn test_init_from() {
// Contiguous case
let src = NdTensor::arange(0, 4, None).into_shape([2, 2]);
let dest = NdTensor::uninit([2, 2]);
let dest = dest.init_from(&src);
assert_eq!(dest.to_vec(), &[0, 1, 2, 3]);

// Non-contiguous
let dest = NdTensor::uninit([2, 2]);
let dest = dest.init_from(&src.transposed());
assert_eq!(dest.to_vec(), &[0, 2, 1, 3]);
}

#[test]
#[should_panic(expected = "shape mismatch")]
fn test_init_from_shape_mismatch() {
let src = NdTensor::arange(0, 4, None).into_shape([2, 2]);
let dest = NdTensor::uninit([2, 3]);
let dest = dest.init_from(&src);
assert_eq!(dest.to_vec(), &[0, 1, 2, 3]);
}

#[test]
fn test_into_data() {
let tensor = NdTensor::from_data([2], vec![2., 3.]);
assert_eq!(tensor.into_data(), vec![2., 3.]);

let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
tensor.transpose();
assert_eq!(tensor.into_data(), vec![1., 3., 2., 4.]);
}

#[test]
fn test_into_non_contiguous_data() {
let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
tensor.transpose();
assert_eq!(tensor.into_non_contiguous_data(), vec![1., 2., 3., 4.]);
}

#[test]
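
The two tensor APIs added above, init_from and into_non_contiguous_data, let a freshly allocated (possibly uninitialized) buffer be filled from an existing view and let a tensor's backing buffer be recovered in whatever order it is stored. A small usage sketch mirroring the new tests:

use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

fn main() {
    // Fill an uninitialized 2x2 tensor from a non-contiguous view.
    let src = NdTensor::arange(0, 4, None).into_shape([2, 2]);
    let dst = NdTensor::uninit([2, 2]).init_from(&src.transposed());
    assert_eq!(dst.to_vec(), &[0, 2, 1, 3]);

    // Recover the backing buffer in its stored order, without making the
    // tensor contiguous first.
    let mut t = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    t.transpose();
    assert_eq!(t.into_non_contiguous_data(), vec![1., 2., 3., 4.]);
}
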
37 changes: 30 additions & 7 deletions src/graph.rs
@@ -1,3 +1,4 @@
use std::env;
use std::error::Error;
use std::fmt;
use std::iter::zip;
@@ -11,6 +12,7 @@ use rten_tensor::Tensor;
use rustc_hash::{FxHashMap, FxHashSet};

use crate::ops::{Input, InputList, OpError, Operator, Output};
use crate::tensor_pool::TensorPool;
use crate::timer::Timer;
use crate::timing::{InputShape, RunTiming, TimingRecord, TimingSort};

@@ -401,6 +403,14 @@ impl Graph {
temp_value_refcount.inc(*node_id);
}

// Create a pool to re-use buffers across execution steps.
//
// If the feature flag is off, we still create the pool, but never
// release buffers back into it, so all allocations use the system
// allocator.
let pool = TensorPool::new();
let use_pool = env::var_os("RTEN_USE_POOL").is_some();

// Execute the plan
let mut temp_values: FxHashMap<NodeId, Output> = FxHashMap::default();
let mut op_elapsed: Vec<TimingRecord> = Vec::new();
@@ -508,7 +518,9 @@
.run_in_place(input, InputList::from_optional(op_inputs))
.map(|out| [out].into())
} else {
op_node.operator.run(InputList::from_optional(op_inputs))
op_node
.operator
.run(&pool, InputList::from_optional(op_inputs))
};

if record_timing {
@@ -581,7 +593,12 @@
for node_id in op_node.inputs.iter().filter_map(|node| *node) {
let rc = temp_value_refcount.dec(node_id);
if rc == Some(0) {
temp_values.remove(&node_id);
if let (true, Some(tensor)) = (use_pool, temp_values.remove(&node_id)) {
match tensor {
Output::FloatTensor(t) => pool.add(t),
Output::IntTensor(t) => pool.add(t),
}
}
}
}
record_timing.then(|| alloc_timer.end());
@@ -594,6 +611,11 @@
plan.len(),
run_timer.elapsed_ms()
);
println!(
"Pool allocs {} hits {}",
pool.alloc_count(),
pool.hit_count()
);
let timing = RunTiming {
records: &op_elapsed,
alloc_time: alloc_timer.elapsed_ms(),
@@ -851,6 +873,7 @@ mod tests {
use crate::ops::{
Add, Concat, Conv, InputList, IntoOpResult, OpError, Operator, Output, Relu, Shape,
};
use crate::tensor_pool::TensorPool;

#[derive(Clone, Debug, Default)]
struct Metrics {
@@ -894,12 +917,12 @@
self.inner.is_commutative()
}

fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
{
let mut m = self.metrics.lock().unwrap();
m.run_count += 1;
}
self.inner.run(inputs)
self.inner.run(pool, inputs)
}

fn run_in_place(&self, output: Output, inputs: InputList) -> Result<Output, OpError> {
@@ -1061,7 +1084,7 @@
"AddOne"
}

fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
let input: TensorView<f32> = inputs.require_as(0)?;
let output_data: Vec<f32> = input.iter().map(|x| x + 1.0).collect();
Tensor::<f32>::from_data(input.shape().into(), output_data).into_op_result()
@@ -1324,7 +1347,7 @@
true
}

fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
// An operator should normally have the same behavior in `run`
// and `run_in_place`. Here we use different behavior to make it
// possible to distinguish which path was used.
@@ -1485,7 +1508,7 @@
"Split"
}

fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
{
let mut rc = self.run_count.lock().unwrap();
*rc += 1;
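
Buffer recycling is currently opt-in via the RTEN_USE_POOL environment variable: the pool is always created, but finished values are only handed back to it when the variable is set. A simplified, self-contained illustration of the release path above (not the crate's code), using plain Vecs in place of tensors:

use std::collections::HashMap;

// Simplified illustration of the release path added to `Graph::run`: when a
// temporary value's refcount reaches zero it is removed from the live-value
// map and, if pooling is enabled, its buffer is kept on a free list for
// re-use instead of being dropped.
fn release_value(
    use_pool: bool,
    free_buffers: &mut Vec<Vec<f32>>,
    temp_values: &mut HashMap<usize, Vec<f32>>,
    refcounts: &mut HashMap<usize, usize>,
    node_id: usize,
) {
    let rc = refcounts.get_mut(&node_id).map(|rc| {
        *rc -= 1;
        *rc
    });
    if rc == Some(0) {
        // The value is always removed; it is only recycled when pooling is
        // enabled, otherwise dropping it returns it to the system allocator.
        if let (true, Some(buf)) = (use_pool, temp_values.remove(&node_id)) {
            free_buffers.push(buf);
        }
    }
}

fn main() {
    let mut free: Vec<Vec<f32>> = Vec::new();
    let mut temps = HashMap::from([(1_usize, vec![0.0_f32; 16])]);
    let mut refcounts = HashMap::from([(1_usize, 1_usize)]);
    release_value(true, &mut free, &mut temps, &mut refcounts, 1);
    assert!(temps.is_empty());
    assert_eq!(free.len(), 1);
}
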
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -51,6 +51,7 @@ mod model;
mod model_metadata;
mod number;
mod slice_reductions;
mod tensor_pool;
mod timer;
mod timing;

@@ -67,6 +68,7 @@ pub use graph::{Dimension, NodeId, RunOptions};
pub use model::{DefaultOperatorFactory, Model, ModelLoadError, NodeInfo, OpRegistry, ReadOpError};
pub use model_metadata::ModelMetadata;
pub use ops::{FloatOperators, Input, Operators, Output};
pub use tensor_pool::TensorPool;
pub use timer::Timer;
pub use timing::TimingSort;

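
With TensorPool re-exported from the crate root, downstream code can construct a pool directly, as the updated jina_similarity example does. A minimal sketch using only the calls visible in this diff (new, alloc_count, hit_count):

use rten::TensorPool;

fn main() {
    // An empty pool; operator functions that accept a pool argument (such as
    // `rten::ops::concat` in the updated example above) can draw temporary
    // buffers from it.
    let pool = TensorPool::new();

    // Counters also reported by the graph's verbose timing output.
    println!("Pool allocs {} hits {}", pool.alloc_count(), pool.hit_count());
}
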