Convert Gemm, Expand ops to use pool
While converting the Gemm operator, also modify it to avoid redundant zeroing of
the output buffer if `beta` is zero. The same optimization was applied to the
MatMul operator previously.
robertknight committed Apr 25, 2024
1 parent ea1e611 commit 15c974e
Showing 3 changed files with 92 additions and 38 deletions.
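
Both ops now draw their output buffers from the `TensorPool` passed to `Operator::run` (`pool.alloc_vec` in `expand_to`, `pool.alloc` in `gemm_op`) instead of allocating fresh memory on every call. As a rough sketch of the technique only — a toy pool, not rten's actual `TensorPool` — the idea is that released allocations are handed back out, so repeated operator invocations avoid hitting the heap allocator:

```rust
use std::cell::RefCell;

/// A toy buffer pool (illustrative only; not rten's `TensorPool`).
struct VecPool {
    free: RefCell<Vec<Vec<f32>>>,
}

impl VecPool {
    fn new() -> Self {
        VecPool {
            free: RefCell::new(Vec::new()),
        }
    }

    /// Hand out an empty buffer with capacity for at least `len` elements,
    /// reusing a previously released allocation when one is available.
    fn alloc_vec(&self, len: usize) -> Vec<f32> {
        let mut buf = self.free.borrow_mut().pop().unwrap_or_default();
        buf.clear();
        buf.reserve(len);
        buf
    }

    /// Return a buffer so a later `alloc_vec` can reuse its allocation.
    fn release(&self, buf: Vec<f32>) {
        self.free.borrow_mut().push(buf);
    }
}
```

In this sketch an operator would call `alloc_vec` for its output and `release` intermediate buffers once consumed; the diffs below thread a `pool` reference through `expand`/`expand_to` and `gemm_op` to the same end.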
1 change: 1 addition & 0 deletions src/gemm.rs
@@ -230,6 +230,7 @@ impl<'a> GemmInputB<'a> {
 ///
 /// This computes `output = alpha * (a @ b) + beta * output` where `@` is
 /// matrix multiplication.
+#[allow(unused)]
 pub fn gemm(
     out_data: &mut [f32],
     out_row_stride: usize,
47 changes: 29 additions & 18 deletions src/ops/layout.rs
@@ -27,18 +27,23 @@ fn expand_output_shape(
 
 /// Broadcast `input` to `out_shape`. This assumes that `out_shape` has already
 /// been verified to be a valid broadcast target.
-pub(crate) fn expand_to<T: Copy>(input: TensorView<T>, out_shape: &[usize]) -> Tensor<T> {
+pub(crate) fn expand_to<T: Any + Copy>(
+    pool: &TensorPool,
+    input: TensorView<T>,
+    out_shape: &[usize],
+) -> Tensor<T> {
+    let out_len = out_shape.iter().product();
+    let mut out_data: Vec<T> = pool.alloc_vec(out_len);
+
     match (
         input.data(),
         fast_broadcast_cycles_repeats(input.shape(), out_shape),
     ) {
         // Fast path for common case of contiguous input and broadcast that can
         // be performed using cycle + repeat.
         (Some(in_data), Some((cycles, repeats))) => {
-            let out_len = out_shape.iter().product();
             assert!(out_len == input.len() * cycles * repeats);
 
-            let mut out_data: Vec<T> = Vec::with_capacity(out_len);
             let mut out_ptr = out_data.as_mut_ptr();
             for _ in 0..cycles {
                 if repeats == 1 {
@@ -63,16 +68,17 @@ pub(crate) fn expand_to<T: Copy>(input: TensorView<T>, out_shape: &[usize]) -> T
 
             Tensor::from_data(out_shape, out_data)
         }
-        _ => input.broadcast(out_shape).to_tensor(),
+        _ => input.broadcast(out_shape).to_tensor_buf(out_data),
     }
 }
 
-pub fn expand<T: Copy>(
+pub fn expand<T: Any + Copy>(
+    pool: &TensorPool,
     input: TensorView<T>,
     shape: &NdTensorView<i32, 1>,
 ) -> Result<Tensor<T>, OpError> {
     let out_shape = expand_output_shape(input.shape(), shape)?;
-    Ok(expand_to(input, &out_shape))
+    Ok(expand_to(pool, input, &out_shape))
 }
 
 #[derive(Debug)]
@@ -83,14 +89,14 @@ impl Operator for Expand {
         "Expand"
     }
 
-    fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
+    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
         let input = inputs.require(0)?;
         let shape = inputs.require_as(1)?;
         let shape = static_dims!(shape, 1)?;
 
         match input {
-            Input::FloatTensor(input) => expand(input, &shape).into_op_result(),
-            Input::IntTensor(input) => expand(input, &shape).into_op_result(),
+            Input::FloatTensor(input) => expand(pool, input, &shape).into_op_result(),
+            Input::IntTensor(input) => expand(pool, input, &shape).into_op_result(),
         }
     }
 
@@ -109,9 +115,10 @@ impl Operator for Expand {
             return Ok(input);
         }
 
+        let pool = TensorPool::new();
         let output: Output = match input {
-            Output::FloatTensor(input) => expand_to(input.view(), &out_shape).into(),
-            Output::IntTensor(input) => expand_to(input.view(), &out_shape).into(),
+            Output::FloatTensor(input) => expand_to(&pool, input.view(), &out_shape).into(),
+            Output::IntTensor(input) => expand_to(&pool, input.view(), &out_shape).into(),
         };
         Ok(output)
     }
@@ -563,53 +570,57 @@ mod tests {
 
     #[test]
     fn test_expand() {
+        let pool = new_pool();
+
         // Broadcast scalar
         let input = tensor!(5.);
         let shape = ndtensor!([2, 2]);
         let expected = Tensor::from_data(&[2, 2], vec![5., 5., 5., 5.]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(&result, &expected);
 
         // Broadcast that changes dim count
         let input = Tensor::from_data(&[3, 1], (0..3).collect::<Vec<_>>());
         let shape = ndtensor!([2, 3, 1]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 3, 1]);
 
         // Broadcast that uses dimensions from both the input shape and target
         // shape in the output shape.
         let input = Tensor::from_data(&[3, 1], (0..3).collect::<Vec<_>>());
         let shape = ndtensor!([2, 1, 6]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 3, 6]);
 
         // Broadcast that does not change dim count
         let input = Tensor::from_data(&[3, 1], (0..3).collect::<Vec<_>>());
         let shape = ndtensor!([3, 4]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[3, 4]);
 
         // Broadcast of leading and trailing dims
         let input = tensor!((1, 2, 1); [1, 2]);
         let shape = ndtensor!([2, 2, 2]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 2, 2]);
         assert_eq!(result.to_vec(), &[1, 1, 2, 2, 1, 1, 2, 2]);
 
         // Broadcast of inner dim
         let input = tensor!((2, 1, 2); [1, 2, 3, 4]);
         let shape = ndtensor!([2, 2, 2]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 2, 2]);
         assert_eq!(result.to_vec(), &[1, 2, 1, 2, 3, 4, 3, 4]);
     }
 
     #[test]
     fn test_expand_invalid_inputs() {
+        let pool = new_pool();
+
         // Invalid broadcast shape
         let input = tensor!([1, 2, 3]);
         let shape = ndtensor!([2, 2]);
-        let result = expand(input.view(), &shape.view());
+        let result = expand(&pool, input.view(), &shape.view());
         assert_eq!(
             result.err(),
             Some(OpError::IncompatibleInputShapes(
82 changes: 62 additions & 20 deletions src/ops/matmul.rs
@@ -4,7 +4,7 @@ use rten_tensor::prelude::*;
 use rten_tensor::{Tensor, TensorView};
 
 use crate::check_dims;
-use crate::gemm::{gemm, GemmExecutor, GemmInputA, GemmInputB};
+use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
 use crate::ops::binary_elementwise::broadcast_shapes;
 use crate::ops::layout::expand_to;
 use crate::ops::{InputList, IntoOpResult, OpError, Operator, Output};
@@ -25,6 +25,7 @@ pub struct Gemm {
 ///
 /// nb. This is named `gemm_op` to avoid confusion with `gemm::gemm`.
 pub fn gemm_op(
+    pool: &TensorPool,
     a: TensorView,
     b: TensorView,
     c: Option<TensorView>,
@@ -40,29 +41,42 @@ pub fn gemm_op(
     let b = if transpose_b { b.transposed() } else { b };
 
     let out_shape = &[a.size(0), b.size(1)][..];
-    let mut output = match c {
+    let gemm = GemmExecutor::new();
+
+    let output = match c {
         Some(c) if beta != 0. => {
             if !c.can_broadcast_to(out_shape) {
                 return Err(OpError::IncompatibleInputShapes(
                     "Cannot broadcast c to output shape",
                 ));
             }
-            expand_to(c, out_shape)
+            let mut output = expand_to(pool, c, out_shape);
+            let out_row_stride = output.stride(0);
+            gemm.gemm(
+                output.data_mut().unwrap(),
+                out_row_stride,
+                GemmInputA::Unpacked(a.nd_view()),
+                GemmInputB::Unpacked(b.nd_view()),
+                alpha,
+                beta,
+            );
+            output
         }
-        _ => Tensor::zeros(out_shape),
+        _ => {
+            let mut output = pool.alloc(out_shape);
+            let out_row_stride = output.stride(0);
+            gemm.gemm_uninit(
+                output.data_mut().unwrap(),
+                out_row_stride,
+                GemmInputA::Unpacked(a.nd_view()),
+                GemmInputB::Unpacked(b.nd_view()),
+                alpha,
+            );
+            // Safety: `gemm_uninit` initialized all elements
+            unsafe { output.assume_init() }
+        }
     };
 
-    let out_row_stride = output.stride(0);
-
-    gemm(
-        output.data_mut().unwrap(),
-        out_row_stride,
-        a.nd_view(),
-        b.nd_view(),
-        alpha,
-        beta,
-    );
-
     Ok(output)
 }
 
@@ -71,11 +85,12 @@ impl Operator for Gemm {
         "Gemm"
     }
 
-    fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
+    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
         let a = inputs.require_as(0)?;
        let b = inputs.require_as(1)?;
         let c = inputs.get_as(2)?;
         gemm_op(
+            pool,
             a,
             b,
             c,
@@ -293,14 +308,16 @@ mod tests {
 
     #[test]
     fn test_gemm_op() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[3, 10], &mut rng);
         let b = Tensor::rand(&[10, 8], &mut rng);
 
         let mut expected = Tensor::zeros(&[3, 8]);
         gemm_tensors(&mut expected, &a, &b, 1., 1.);
 
-        let result = gemm_op(a.view(), b.view(), None, 1.0, 1.0, false, false).unwrap();
+        let result = gemm_op(&pool, a.view(), b.view(), None, 1.0, 1.0, false, false).unwrap();
 
         expect_equal(&result, &expected)?;
 
@@ -309,6 +326,8 @@
 
     #[test]
     fn test_gemm_op_transposed() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[10, 3], &mut rng);
         let b = Tensor::rand(&[8, 10], &mut rng);
@@ -320,7 +339,7 @@
         let mut expected = Tensor::zeros(&[3, 8]);
         gemm_tensors(&mut expected, &a_transposed, &b_transposed, 1., 1.);
 
-        let result = gemm_op(a.view(), b.view(), None, 1.0, 1.0, true, true).unwrap();
+        let result = gemm_op(&pool, a.view(), b.view(), None, 1.0, 1.0, true, true).unwrap();
 
         expect_equal(&result, &expected)?;
 
@@ -329,6 +348,8 @@
 
     #[test]
     fn test_gemm_op_adds_c() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[3, 10], &mut rng);
         let b = Tensor::rand(&[10, 8], &mut rng);
@@ -337,7 +358,17 @@
         let mut expected = c.clone();
         gemm_tensors(&mut expected, &a, &b, 1., 1.);
 
-        let result = gemm_op(a.view(), b.view(), Some(c.view()), 1.0, 1.0, false, false).unwrap();
+        let result = gemm_op(
+            &pool,
+            a.view(),
+            b.view(),
+            Some(c.view()),
+            1.0,
+            1.0,
+            false,
+            false,
+        )
+        .unwrap();
 
         expect_equal(&result, &expected)?;
 
@@ -346,12 +377,23 @@
 
     #[test]
     fn test_gemm_op_invalid_inputs() {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[3, 10], &mut rng);
         let b = Tensor::rand(&[10, 8], &mut rng);
         let c = Tensor::rand(&[3, 5], &mut rng);
 
-        let result = gemm_op(a.view(), b.view(), Some(c.view()), 1.0, 1.0, false, false);
+        let result = gemm_op(
+            &pool,
+            a.view(),
+            b.view(),
+            Some(c.view()),
+            1.0,
+            1.0,
+            false,
+            false,
+        );
 
         assert_eq!(
             result.err(),
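
The `beta == 0` fast path above relies on `gemm_uninit` writing every element of the output exactly once, so the buffer no longer needs to be zeroed with `Tensor::zeros` first — contents that would only have been multiplied by `beta == 0` and discarded. A standalone sketch of that idea, using hypothetical names rather than rten's `GemmExecutor` API:

```rust
use std::mem::MaybeUninit;

/// Hypothetical helper: compute `out = alpha * (a @ b)` into uninitialized
/// memory, for row-major `a` (m x k) and `b` (k x n). Illustrative only.
fn gemm_uninit_sketch(
    out: &mut [MaybeUninit<f32>],
    a: &[f32],
    b: &[f32],
    m: usize,
    k: usize,
    n: usize,
    alpha: f32,
) {
    assert_eq!(out.len(), m * n);
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            // Each element is written exactly once, so the caller can
            // soundly `assume_init` the whole buffer afterwards.
            out[i * n + j].write(alpha * acc);
        }
    }
}
```

Writing once into `MaybeUninit` storage and then assuming it initialized is sound only because every element gets written — the invariant that the `// Safety: gemm_uninit initialized all elements` comment in the diff records.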
Expand Down
