Convert remaining pooling ops to use pool and avoid redundant zeroing
Convert remaining pooling ops to allocate from the pool and avoid redundant
zeroing of the output buffer.
robertknight committed Apr 25, 2024
1 parent a886381 commit c50fa83
Showing 1 changed file with 32 additions and 12 deletions.
44 changes: 32 additions & 12 deletions src/ops/pooling.rs
@@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
 
 use rayon::prelude::*;
 use rten_tensor::prelude::*;
-use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView};
+use rten_tensor::{NdTensorView, NdTensorViewMut, Tensor, TensorView};
 
 use crate::check_dims;
 use crate::gemm::div_ceil;
@@ -90,6 +90,7 @@ pub fn calc_output_size_and_padding(
 }
 
 pub fn average_pool(
+    pool: &TensorPool,
     input: TensorView,
     kernel_size: [usize; 2],
     strides: [usize; 2],
@@ -108,9 +109,10 @@ pub fn average_pool(
     let [kernel_h, kernel_w] = kernel_size;
     let [stride_h, stride_w] = strides;
 
-    let mut output = NdTensor::zeros([batch, in_c, out_h, out_w]);
+    let mut output = pool.alloc([batch, in_c, out_h, out_w]);
     let input = input.nd_view::<4>();
 
+    let mut n_init = 0;
     for n in 0..batch {
         for chan in 0..in_c {
             let mut out_view = output.slice_mut([n, chan]);
@@ -144,12 +146,16 @@
                         non_padding_elements
                     };
 
-                    out_view[[out_y, out_x]] = accumulator / counted_elems;
+                    out_view[[out_y, out_x]].write(accumulator / counted_elems);
+                    n_init += 1;
                 }
             }
         }
     }
 
+    assert!(n_init == output.len());
+    let output = unsafe { output.assume_init() };
+
     Ok(output.into_dyn())
 }
 
@@ -166,9 +172,10 @@ impl Operator for AveragePool {
         "AveragePool"
     }
 
-    fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
+    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
         let input = inputs.require_as(0)?;
         average_pool(
+            pool,
             input,
             self.kernel_size,
             self.strides,
@@ -179,10 +186,11 @@ impl Operator for AveragePool {
     }
 }
 
-pub fn global_average_pool(input: TensorView) -> Result<Tensor, OpError> {
+pub fn global_average_pool(pool: &TensorPool, input: TensorView) -> Result<Tensor, OpError> {
     let [batch, chans, in_h, in_w] = check_dims!(input, 4, "NCHW");
 
-    let mut output = Tensor::zeros(&[batch, chans, 1, 1]);
+    let mut output = pool.alloc([batch, chans, 1, 1]);
+    let mut n_init = 0;
 
     for n in 0..batch {
         const N: usize = 4;
@@ -208,19 +216,24 @@ pub fn global_average_pool(input: TensorView) -> Result<Tensor, OpError> {
                 }
 
                 for i in 0..N {
-                    out_group[[i]] = sums[i] / (in_h * in_w) as f32;
+                    out_group[[i]].write(sums[i] / (in_h * in_w) as f32);
                 }
+                n_init += N;
             } else {
                 // Compute average over remaining channels.
                 for i in 0..chan_group.size(0) {
                     let sum: f32 = chan_group.slice::<2, _>([i]).iter().sum();
-                    out_group[[i]] = sum / (in_h * in_w) as f32;
+                    out_group[[i]].write(sum / (in_h * in_w) as f32);
+                    n_init += 1;
                 }
             }
         }
     }
 
-    Ok(output)
+    assert!(n_init == output.len());
+    let output = unsafe { output.assume_init() };
+
+    Ok(output.into_dyn())
 }
 
 #[derive(Debug)]
@@ -231,9 +244,9 @@ impl Operator for GlobalAveragePool {
         "GlobalAveragePool"
     }
 
-    fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
+    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
         let input = inputs.require_as(0)?;
-        global_average_pool(input).into_op_result()
+        global_average_pool(pool, input).into_op_result()
     }
 }
 
@@ -454,8 +467,10 @@ mod tests {
             },
         ];
 
+        let pool = new_pool();
        for case in cases {
            let result = average_pool(
+                &pool,
                input.view(),
                case.kernel_size,
                case.strides,
@@ -471,6 +486,8 @@
 
     #[test]
     fn test_average_pool_padding() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut input = Tensor::from([
             [0.0809, 0.5529, 0.1534, 0.7507],
             [0.4698, 0.7771, 0.9896, 0.4873],
@@ -491,6 +508,7 @@
         expected.reshape(&[1, 1, rows, cols]);
 
         let result = average_pool(
+            &pool,
             input.view(),
             [2, 2],
             [2, 2], /* stride */
@@ -509,6 +527,7 @@
             .into_shape([1, 1, 3, 3])
             .into_dyn();
         let result = average_pool(
+            &pool,
             input.view(),
             [2, 2],
             [2, 2], /* stride */
@@ -523,9 +542,10 @@
 
     #[test]
     fn test_global_average_pool() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
         let input = Tensor::from_data(&[1, 2, 2, 2], vec![1., 2., 3., 4., 10., 20., 30., 40.]);
         let expected = Tensor::from_data(&[1, 2, 1, 1], vec![2.5, 25.]);
-        let result = global_average_pool(input.view()).unwrap();
+        let result = global_average_pool(&pool, input.view()).unwrap();
         expect_equal(&result, &expected)?;
         Ok(())
     }
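The core pattern this diff applies in both ops: start from an uninitialized output buffer, write every element exactly once while counting writes, assert that the count equals the output length, and only then call the unsafe `assume_init`. Here is a minimal standalone sketch of that pattern using only `std` (plain `Vec` stands in for rten's pooled tensors, and `squares` is an invented example function, not part of the commit):

use std::mem::MaybeUninit;

// Skip the zeroing pass: write each output slot exactly once, count the
// writes, and only then treat the buffer as initialized.
fn squares(n: usize) -> Vec<f32> {
    let mut out: Vec<f32> = Vec::with_capacity(n);
    let uninit: &mut [MaybeUninit<f32>] = out.spare_capacity_mut();

    let mut n_init = 0;
    for (i, slot) in uninit[..n].iter_mut().enumerate() {
        slot.write((i * i) as f32); // analogous to `out_view[[out_y, out_x]].write(..)`
        n_init += 1;
    }

    // Mirrors the commit's `assert!(n_init == output.len())`: the unsafe
    // step below is sound only if every element was written.
    assert!(n_init == n);
    // SAFETY: all `n` elements were initialized in the loop above.
    unsafe { out.set_len(n) };
    out
}

fn main() {
    assert_eq!(squares(4), vec![0.0, 1.0, 4.0, 9.0]);
}

The assert-before-assume_init guard turns a missed write (say, from a wrongly computed output size) into a deterministic panic instead of undefined behavior.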

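The other half of the change, allocating from a pool, helps for a related reason: `Tensor::zeros` pays for a heap allocation plus a zeroing pass on every call, while a pooled buffer reuses a prior allocation and skips the zeroing entirely. A hypothetical sketch of the idea (this is not rten's actual `TensorPool` API; `BufferPool`, `alloc`, and `release` are invented for illustration):

use std::cell::RefCell;
use std::mem::MaybeUninit;

// A toy pool that hands out uninitialized buffers and recycles released ones.
struct BufferPool {
    free: RefCell<Vec<Vec<MaybeUninit<f32>>>>,
}

impl BufferPool {
    fn new() -> Self {
        BufferPool { free: RefCell::new(Vec::new()) }
    }

    // Hand out an uninitialized buffer of `len` elements, reusing a
    // previously released allocation when one is available.
    fn alloc(&self, len: usize) -> Vec<MaybeUninit<f32>> {
        let mut buf = self.free.borrow_mut().pop().unwrap_or_default();
        buf.clear();
        // `MaybeUninit::uninit()` placeholders write no element data;
        // `vec![0.0; len]` would have to zero every slot.
        buf.resize_with(len, MaybeUninit::uninit);
        buf
    }

    // Return a buffer so later ops can reuse its allocation.
    fn release(&self, buf: Vec<MaybeUninit<f32>>) {
        self.free.borrow_mut().push(buf);
    }
}

fn main() {
    let pool = BufferPool::new();
    let a = pool.alloc(1024);
    pool.release(a);
    let b = pool.alloc(512); // reuses the 1024-element allocation
    assert!(b.capacity() >= 1024);
}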