Skip to content

Commit

Permalink
Add an arena/pool to enable tensor buffer re-use during graph execution
Browse files Browse the repository at this point in the history
Improve buffer re-use during graph execution by adding a pool from which
operators can allocate output buffers, and into which buffers are added
when their ref count drops to zero (i.e., when they are no longer needed
by subsequent graph execution steps). This significantly reduces how
often execution needs to allocate "fresh" buffers from the system
allocator and free them back.

In this initial implementation, a reference to the pool is passed to all
operators via `Operator::run`, but only a subset actually use the pool.
This subset was chosen to benefit the YOLOv8 example.

 - Add `pool` argument to `Operator::run`, specifying a pool from which
   operators should allocate their outputs

 - Create a pool at the start of graph execution and release it at the end.
   Intermediate values that are no longer needed are added to the pool after
   each operator runs.

 - Report the number of allocations from the pool and the hit rate (how often
   the pool was able to satisfy allocations) as part of timing info.

 - Modify an initial subset of operators to allocate their outputs from the pool, based on
   what helps the YOLOv8 example.
  • Loading branch information
robertknight committed Apr 23, 2024
1 parent a3e2dd9 commit 11be82a
Show file tree
Hide file tree
Showing 28 changed files with 744 additions and 276 deletions.
5 changes: 3 additions & 2 deletions rten-examples/src/jina_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::error::Error;
use std::fs;

use rten::ops::concat;
use rten::{FloatOperators, Input, Model, NodeId, Operators};
use rten::{FloatOperators, Input, Model, NodeId, Operators, TensorPool};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};
use rten_text::tokenizers::{EncodeOptions, Tokenizer};
Expand Down Expand Up @@ -161,7 +161,8 @@ fn embed_sentence_batch(
view
})
.collect();
let mean_pooled: NdTensor<f32, 2> = concat(&mean_pooled_views, 0)?.try_into()?;
let pool = TensorPool::new();
let mean_pooled: NdTensor<f32, 2> = concat(&pool, &mean_pooled_views, 0)?.try_into()?;
Ok(mean_pooled)
}

Expand Down
31 changes: 24 additions & 7 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rten_tensor::Tensor;
use rustc_hash::{FxHashMap, FxHashSet};

use crate::ops::{Input, InputList, OpError, Operator, Output};
use crate::tensor_pool::TensorPool;
use crate::timer::Timer;
use crate::timing::{InputShape, RunTiming, TimingRecord, TimingSort};

Expand Down Expand Up @@ -401,6 +402,9 @@ impl Graph {
temp_value_refcount.inc(*node_id);
}

// Create a pool to re-use buffers across execution steps.
let pool = TensorPool::new();

// Execute the plan
let mut temp_values: FxHashMap<NodeId, Output> = FxHashMap::default();
let mut op_elapsed: Vec<TimingRecord> = Vec::new();
Expand Down Expand Up @@ -508,7 +512,9 @@ impl Graph {
.run_in_place(input, InputList::from_optional(op_inputs))
.map(|out| [out].into())
} else {
op_node.operator.run(InputList::from_optional(op_inputs))
op_node
.operator
.run(&pool, InputList::from_optional(op_inputs))
};

if record_timing {
Expand Down Expand Up @@ -581,7 +587,12 @@ impl Graph {
for node_id in op_node.inputs.iter().filter_map(|node| *node) {
let rc = temp_value_refcount.dec(node_id);
if rc == Some(0) {
temp_values.remove(&node_id);
if let Some(tensor) = temp_values.remove(&node_id) {
match tensor {
Output::FloatTensor(t) => pool.add(t),
Output::IntTensor(t) => pool.add(t),
}
}
}
}
record_timing.then(|| alloc_timer.end());
Expand All @@ -594,6 +605,11 @@ impl Graph {
plan.len(),
run_timer.elapsed_ms()
);
println!(
"Pool allocs {} hits {}",
pool.alloc_count(),
pool.hit_count()
);
let timing = RunTiming {
records: &op_elapsed,
alloc_time: alloc_timer.elapsed_ms(),
Expand Down Expand Up @@ -851,6 +867,7 @@ mod tests {
use crate::ops::{
Add, Concat, Conv, InputList, IntoOpResult, OpError, Operator, Output, Relu, Shape,
};
use crate::tensor_pool::TensorPool;

#[derive(Clone, Debug, Default)]
struct Metrics {
Expand Down Expand Up @@ -894,12 +911,12 @@ mod tests {
self.inner.is_commutative()
}

fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
{
let mut m = self.metrics.lock().unwrap();
m.run_count += 1;
}
self.inner.run(inputs)
self.inner.run(pool, inputs)
}

fn run_in_place(&self, output: Output, inputs: InputList) -> Result<Output, OpError> {
Expand Down Expand Up @@ -1061,7 +1078,7 @@ mod tests {
"AddOne"
}

fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
let input: TensorView<f32> = inputs.require_as(0)?;
let output_data: Vec<f32> = input.iter().map(|x| x + 1.0).collect();
Tensor::<f32>::from_data(input.shape().into(), output_data).into_op_result()
Expand Down Expand Up @@ -1324,7 +1341,7 @@ mod tests {
true
}

fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
// An operator should normally have the same behavior in `run`
// and `run_in_place`. Here we use different behavior to make it
// possible to distinguish which path was used.
Expand Down Expand Up @@ -1485,7 +1502,7 @@ mod tests {
"Split"
}

fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
{
let mut rc = self.run_count.lock().unwrap();
*rc += 1;
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ mod model;
mod model_metadata;
mod number;
mod slice_reductions;
mod tensor_pool;
mod timer;
mod timing;

Expand All @@ -67,6 +68,7 @@ pub use graph::{Dimension, NodeId, RunOptions};
pub use model::{DefaultOperatorFactory, Model, ModelLoadError, NodeInfo, OpRegistry, ReadOpError};
pub use model_metadata::ModelMetadata;
pub use ops::{FloatOperators, Input, Operators, Output};
pub use tensor_pool::TensorPool;
pub use timer::Timer;
pub use timing::TimingSort;

Expand Down
Loading

0 comments on commit 11be82a

Please sign in to comment.