Merge pull request #115 from robertknight/gemm-op-pool
Convert Gemm, Expand ops to use pool
robertknight authored Apr 25, 2024
2 parents ea1e611 + 15c974e commit a886381
Showing 3 changed files with 92 additions and 38 deletions.
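For context on the change: a `TensorPool` lets operators obtain output and scratch buffers from a reusable pool instead of allocating fresh memory on every call. Below is a minimal sketch of the idea, in the spirit of the `pool.alloc_vec` / `pool.alloc` calls in the diffs that follow — illustrative only, not rten's actual `TensorPool` implementation:

use std::cell::RefCell;

/// Illustrative buffer pool. `alloc_vec` prefers to reuse a previously
/// released buffer with enough capacity; otherwise it allocates fresh.
struct VecPool {
    free: RefCell<Vec<Vec<f32>>>,
}

impl VecPool {
    fn new() -> VecPool {
        VecPool { free: RefCell::new(Vec::new()) }
    }

    /// Hand out an empty Vec with capacity for at least `len` elements.
    fn alloc_vec(&self, len: usize) -> Vec<f32> {
        let mut free = self.free.borrow_mut();
        match free.iter().position(|buf| buf.capacity() >= len) {
            Some(pos) => {
                let mut buf = free.swap_remove(pos);
                buf.clear();
                buf
            }
            None => Vec::with_capacity(len),
        }
    }

    /// Return a buffer to the pool so a later `alloc_vec` can reuse it.
    fn release(&self, buf: Vec<f32>) {
        self.free.borrow_mut().push(buf);
    }
}

Taking `&self` with interior mutability is what lets every operator share a plain `&TensorPool` reference, which is how the `run` methods in the diffs below receive it.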
1 change: 1 addition & 0 deletions src/gemm.rs
@@ -230,6 +230,7 @@ impl<'a> GemmInputB<'a> {
 ///
 /// This computes `output = alpha * (a @ b) + beta * output` where `@` is
 /// matrix multiplication.
+#[allow(unused)]
 pub fn gemm(
     out_data: &mut [f32],
     out_row_stride: usize,
47 changes: 29 additions & 18 deletions src/ops/layout.rs
@@ -27,18 +27,23 @@ fn expand_output_shape(

 /// Broadcast `input` to `out_shape`. This assumes that `out_shape` has already
 /// been verified to be a valid broadcast target.
-pub(crate) fn expand_to<T: Copy>(input: TensorView<T>, out_shape: &[usize]) -> Tensor<T> {
+pub(crate) fn expand_to<T: Any + Copy>(
+    pool: &TensorPool,
+    input: TensorView<T>,
+    out_shape: &[usize],
+) -> Tensor<T> {
+    let out_len = out_shape.iter().product();
+    let mut out_data: Vec<T> = pool.alloc_vec(out_len);
+
     match (
         input.data(),
         fast_broadcast_cycles_repeats(input.shape(), out_shape),
     ) {
         // Fast path for common case of contiguous input and broadcast that can
         // be performed using cycle + repeat.
         (Some(in_data), Some((cycles, repeats))) => {
-            let out_len = out_shape.iter().product();
             assert!(out_len == input.len() * cycles * repeats);
 
-            let mut out_data: Vec<T> = Vec::with_capacity(out_len);
             let mut out_ptr = out_data.as_mut_ptr();
             for _ in 0..cycles {
                 if repeats == 1 {
@@ -63,16 +68,17 @@ pub(crate) fn expand_to<T: Copy>(input: TensorView<T>, out_shape: &[usize]) -> T

             Tensor::from_data(out_shape, out_data)
         }
-        _ => input.broadcast(out_shape).to_tensor(),
+        _ => input.broadcast(out_shape).to_tensor_buf(out_data),
     }
 }
 
-pub fn expand<T: Copy>(
+pub fn expand<T: Any + Copy>(
+    pool: &TensorPool,
     input: TensorView<T>,
     shape: &NdTensorView<i32, 1>,
 ) -> Result<Tensor<T>, OpError> {
     let out_shape = expand_output_shape(input.shape(), shape)?;
-    Ok(expand_to(input, &out_shape))
+    Ok(expand_to(pool, input, &out_shape))
 }
 
 #[derive(Debug)]
@@ -83,14 +89,14 @@ impl Operator for Expand {
"Expand"
}

fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
let input = inputs.require(0)?;
let shape = inputs.require_as(1)?;
let shape = static_dims!(shape, 1)?;

match input {
Input::FloatTensor(input) => expand(input, &shape).into_op_result(),
Input::IntTensor(input) => expand(input, &shape).into_op_result(),
Input::FloatTensor(input) => expand(pool, input, &shape).into_op_result(),
Input::IntTensor(input) => expand(pool, input, &shape).into_op_result(),
}
}

@@ -109,9 +115,10 @@ impl Operator for Expand {
             return Ok(input);
         }
 
+        let pool = TensorPool::new();
         let output: Output = match input {
-            Output::FloatTensor(input) => expand_to(input.view(), &out_shape).into(),
-            Output::IntTensor(input) => expand_to(input.view(), &out_shape).into(),
+            Output::FloatTensor(input) => expand_to(&pool, input.view(), &out_shape).into(),
+            Output::IntTensor(input) => expand_to(&pool, input.view(), &out_shape).into(),
         };
         Ok(output)
     }
@@ -563,53 +570,57 @@ mod tests {

     #[test]
     fn test_expand() {
+        let pool = new_pool();
+
         // Broadcast scalar
         let input = tensor!(5.);
         let shape = ndtensor!([2, 2]);
         let expected = Tensor::from_data(&[2, 2], vec![5., 5., 5., 5.]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(&result, &expected);
 
         // Broadcast that changes dim count
         let input = Tensor::from_data(&[3, 1], (0..3).collect::<Vec<_>>());
         let shape = ndtensor!([2, 3, 1]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 3, 1]);
 
         // Broadcast that uses dimensions from both the input shape and target
         // shape in the output shape.
        let input = Tensor::from_data(&[3, 1], (0..3).collect::<Vec<_>>());
         let shape = ndtensor!([2, 1, 6]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 3, 6]);
 
         // Broadcast that does not change dim count
         let input = Tensor::from_data(&[3, 1], (0..3).collect::<Vec<_>>());
         let shape = ndtensor!([3, 4]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[3, 4]);
 
         // Broadcast of leading and trailing dims
         let input = tensor!((1, 2, 1); [1, 2]);
         let shape = ndtensor!([2, 2, 2]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 2, 2]);
         assert_eq!(result.to_vec(), &[1, 1, 2, 2, 1, 1, 2, 2]);
 
         // Broadcast of inner dim
         let input = tensor!((2, 1, 2); [1, 2, 3, 4]);
         let shape = ndtensor!([2, 2, 2]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 2, 2]);
         assert_eq!(result.to_vec(), &[1, 2, 1, 2, 3, 4, 3, 4]);
     }
 
     #[test]
     fn test_expand_invalid_inputs() {
+        let pool = new_pool();
+
         // Invalid broadcast shape
         let input = tensor!([1, 2, 3]);
         let shape = ndtensor!([2, 2]);
-        let result = expand(input.view(), &shape.view());
+        let result = expand(&pool, input.view(), &shape.view());
         assert_eq!(
             result.err(),
             Some(OpError::IncompatibleInputShapes(
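A note on the `expand_to` fast path above: `fast_broadcast_cycles_repeats` reduces a broadcast of a contiguous tensor to a pair `(cycles, repeats)` — each input element is repeated `repeats` times, and the whole repeated sequence is cycled `cycles` times. A safe-Rust sketch of that copy pattern (illustrative; the crate's version writes through a raw pointer into the pool-allocated buffer instead):

/// Broadcast a contiguous buffer by repeating each element `repeats`
/// times, then cycling the whole repeated sequence `cycles` times.
/// For example, [1, 2] with cycles = 2, repeats = 2 yields
/// [1, 1, 2, 2, 1, 1, 2, 2].
fn cycle_repeat<T: Copy>(input: &[T], cycles: usize, repeats: usize) -> Vec<T> {
    let mut out = Vec::with_capacity(input.len() * cycles * repeats);
    for _ in 0..cycles {
        if repeats == 1 {
            // Contiguous copy of the whole input, once per cycle.
            out.extend_from_slice(input);
        } else {
            for &x in input {
                for _ in 0..repeats {
                    out.push(x);
                }
            }
        }
    }
    out
}

With `cycles = 2, repeats = 2` this reproduces the `[1, 1, 2, 2, 1, 1, 2, 2]` expectation from the "leading and trailing dims" case in `test_expand` above.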
82 changes: 62 additions & 20 deletions src/ops/matmul.rs
@@ -4,7 +4,7 @@ use rten_tensor::prelude::*;
 use rten_tensor::{Tensor, TensorView};
 
 use crate::check_dims;
-use crate::gemm::{gemm, GemmExecutor, GemmInputA, GemmInputB};
+use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
 use crate::ops::binary_elementwise::broadcast_shapes;
 use crate::ops::layout::expand_to;
 use crate::ops::{InputList, IntoOpResult, OpError, Operator, Output};
@@ -25,6 +25,7 @@ pub struct Gemm {
 ///
 /// nb. This is named `gemm_op` to avoid confusion with `gemm::gemm`.
 pub fn gemm_op(
+    pool: &TensorPool,
     a: TensorView,
     b: TensorView,
     c: Option<TensorView>,
@@ -40,29 +41,42 @@
     let b = if transpose_b { b.transposed() } else { b };
 
     let out_shape = &[a.size(0), b.size(1)][..];
-    let mut output = match c {
+    let gemm = GemmExecutor::new();
+
+    let output = match c {
         Some(c) if beta != 0. => {
             if !c.can_broadcast_to(out_shape) {
                 return Err(OpError::IncompatibleInputShapes(
                     "Cannot broadcast c to output shape",
                 ));
             }
-            expand_to(c, out_shape)
+            let mut output = expand_to(pool, c, out_shape);
+            let out_row_stride = output.stride(0);
+            gemm.gemm(
+                output.data_mut().unwrap(),
+                out_row_stride,
+                GemmInputA::Unpacked(a.nd_view()),
+                GemmInputB::Unpacked(b.nd_view()),
+                alpha,
+                beta,
+            );
+            output
         }
+        _ => {
+            let mut output = pool.alloc(out_shape);
+            let out_row_stride = output.stride(0);
+            gemm.gemm_uninit(
+                output.data_mut().unwrap(),
+                out_row_stride,
+                GemmInputA::Unpacked(a.nd_view()),
+                GemmInputB::Unpacked(b.nd_view()),
+                alpha,
+            );
+            // Safety: `gemm_uninit` initialized all elements
+            unsafe { output.assume_init() }
+        }
-        _ => Tensor::zeros(out_shape),
     };
 
-    let out_row_stride = output.stride(0);
-
-    gemm(
-        output.data_mut().unwrap(),
-        out_row_stride,
-        a.nd_view(),
-        b.nd_view(),
-        alpha,
-        beta,
-    );
-
     Ok(output)
 }

@@ -71,11 +85,12 @@ impl Operator for Gemm {
"Gemm"
}

fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Vec<Output>, OpError> {
let a = inputs.require_as(0)?;
let b = inputs.require_as(1)?;
let c = inputs.get_as(2)?;
gemm_op(
pool,
a,
b,
c,
@@ -293,14 +308,16 @@ mod tests {

     #[test]
     fn test_gemm_op() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[3, 10], &mut rng);
         let b = Tensor::rand(&[10, 8], &mut rng);
 
         let mut expected = Tensor::zeros(&[3, 8]);
         gemm_tensors(&mut expected, &a, &b, 1., 1.);
 
-        let result = gemm_op(a.view(), b.view(), None, 1.0, 1.0, false, false).unwrap();
+        let result = gemm_op(&pool, a.view(), b.view(), None, 1.0, 1.0, false, false).unwrap();
 
         expect_equal(&result, &expected)?;

@@ -309,6 +326,8 @@

     #[test]
     fn test_gemm_op_transposed() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[10, 3], &mut rng);
         let b = Tensor::rand(&[8, 10], &mut rng);
@@ -320,7 +339,7 @@
         let mut expected = Tensor::zeros(&[3, 8]);
         gemm_tensors(&mut expected, &a_transposed, &b_transposed, 1., 1.);
 
-        let result = gemm_op(a.view(), b.view(), None, 1.0, 1.0, true, true).unwrap();
+        let result = gemm_op(&pool, a.view(), b.view(), None, 1.0, 1.0, true, true).unwrap();
 
         expect_equal(&result, &expected)?;

@@ -329,6 +348,8 @@

     #[test]
     fn test_gemm_op_adds_c() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[3, 10], &mut rng);
         let b = Tensor::rand(&[10, 8], &mut rng);
@@ -337,7 +358,17 @@
         let mut expected = c.clone();
         gemm_tensors(&mut expected, &a, &b, 1., 1.);
 
-        let result = gemm_op(a.view(), b.view(), Some(c.view()), 1.0, 1.0, false, false).unwrap();
+        let result = gemm_op(
+            &pool,
+            a.view(),
+            b.view(),
+            Some(c.view()),
+            1.0,
+            1.0,
+            false,
+            false,
+        )
+        .unwrap();
 
         expect_equal(&result, &expected)?;

@@ -346,12 +377,23 @@

     #[test]
     fn test_gemm_op_invalid_inputs() {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[3, 10], &mut rng);
         let b = Tensor::rand(&[10, 8], &mut rng);
         let c = Tensor::rand(&[3, 5], &mut rng);
 
-        let result = gemm_op(a.view(), b.view(), Some(c.view()), 1.0, 1.0, false, false);
+        let result = gemm_op(
+            &pool,
+            a.view(),
+            b.view(),
+            Some(c.view()),
+            1.0,
+            1.0,
+            false,
+            false,
+        );
 
         assert_eq!(
             result.err(),
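One detail worth noting in the new `gemm_op` above: the `beta == 0` arm takes an uninitialized tensor from the pool, lets `gemm_uninit` write every element, and only then calls `assume_init` — so the output is never redundantly zero-filled the way the old `Tensor::zeros` path was. A standalone sketch of that write-then-commit pattern using only `std` types (a trivial fill stands in for the GEMM kernel; the function is illustrative, not the crate's API):

use std::mem::MaybeUninit;

/// Write into uninitialized spare capacity, then mark it initialized.
/// The safety obligation mirrors the diff's "Safety: `gemm_uninit`
/// initialized all elements" comment: every element must be written
/// before the buffer's length is set.
fn fill_uninit(len: usize, value: f32) -> Vec<f32> {
    let mut buf: Vec<f32> = Vec::with_capacity(len);
    // The spare capacity is a `&mut [MaybeUninit<f32>]`: writable,
    // but not yet readable.
    for slot in buf.spare_capacity_mut()[..len].iter_mut() {
        slot.write(value);
    }
    // SAFETY: the loop above initialized exactly the first `len` elements.
    unsafe { buf.set_len(len) };
    buf
}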
