From 15c974e76176e378be8d6f46a9b1329831321e32 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Thu, 25 Apr 2024 19:57:31 +0100
Subject: [PATCH] Convert Gemm, Expand ops to use pool

While converting the Gemm operator, also modify it to avoid redundant
zeroing of the output buffer if `beta` is zero. The same optimization
was applied to the MatMul operator previously.
---
 src/gemm.rs       |  1 +
 src/ops/layout.rs | 47 ++++++++++++++++-----------
 src/ops/matmul.rs | 82 +++++++++++++++++++++++++++++++++++------------
 3 files changed, 92 insertions(+), 38 deletions(-)
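Note on the `beta == 0` change below: a GEMM kernel writes every output
element exactly once, so when `beta` is zero the output never needs to be
zero-filled first; it can start as uninitialized (possibly recycled) memory.
A minimal standalone sketch of that pattern, assuming nothing from rten's
API (the naive `scaled_matmul_into` kernel here is hypothetical and merely
stands in for `GemmExecutor::gemm_uninit`):

    use std::mem::MaybeUninit;

    /// Compute `alpha * (a @ b)` into `out`, writing every element exactly
    /// once. `a` is `m x k`, `b` is `k x n`, both row-major.
    fn scaled_matmul_into(
        out: &mut [MaybeUninit<f32>],
        a: &[f32],
        b: &[f32],
        m: usize,
        k: usize,
        n: usize,
        alpha: f32,
    ) {
        assert_eq!(out.len(), m * n);
        for row in 0..m {
            for col in 0..n {
                let mut acc = 0.;
                for i in 0..k {
                    acc += a[row * k + i] * b[i * n + col];
                }
                // Unconditional write: this is what makes it sound for the
                // caller to treat the buffer as initialized afterwards.
                out[row * n + col].write(alpha * acc);
            }
        }
    }

    fn main() {
        let (m, k, n) = (2, 3, 2);
        let a = [1., 2., 3., 4., 5., 6.];
        let b = [1., 0., 0., 1., 1., 1.];
        // No zero-fill: reserve capacity and let the kernel initialize it.
        let mut out: Vec<f32> = Vec::with_capacity(m * n);
        scaled_matmul_into(&mut out.spare_capacity_mut()[..m * n], &a, &b, m, k, n, 1.0);
        // Safety: the kernel wrote all `m * n` elements.
        unsafe { out.set_len(m * n) };
        assert_eq!(out, [4., 5., 10., 11.]);
    }

In the patch itself, `TensorPool::alloc` provides the uninitialized buffer,
`GemmExecutor::gemm_uninit` takes the role of the kernel above, and
`assume_init` is the analogue of `set_len`.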
diff --git a/src/gemm.rs b/src/gemm.rs
index 603c7e41..069509c4 100644
--- a/src/gemm.rs
+++ b/src/gemm.rs
@@ -230,6 +230,7 @@ impl<'a> GemmInputB<'a> {
 ///
 /// This computes `output = alpha * (a @ b) + beta * output` where `@` is
 /// matrix multiplication.
+#[allow(unused)]
 pub fn gemm(
     out_data: &mut [f32],
     out_row_stride: usize,
diff --git a/src/ops/layout.rs b/src/ops/layout.rs
index 32762709..d04c56e6 100644
--- a/src/ops/layout.rs
+++ b/src/ops/layout.rs
@@ -27,7 +27,14 @@ fn expand_output_shape(
 
 /// Broadcast `input` to `out_shape`. This assumes that `out_shape` has already
 /// been verified to be a valid broadcast target.
-pub(crate) fn expand_to<T: Copy>(input: TensorView<T>, out_shape: &[usize]) -> Tensor<T> {
+pub(crate) fn expand_to<T: Copy>(
+    pool: &TensorPool,
+    input: TensorView<T>,
+    out_shape: &[usize],
+) -> Tensor<T> {
+    let out_len = out_shape.iter().product();
+    let mut out_data: Vec<T> = pool.alloc_vec(out_len);
+
     match (
         input.data(),
         fast_broadcast_cycles_repeats(input.shape(), out_shape),
@@ -35,10 +42,8 @@ pub(crate) fn expand_to<T: Copy>(input: TensorView<T>, out_shape: &[usize]) -> T
         // Fast path for common case of contiguous input and broadcast that can
         // be performed using cycle + repeat.
         (Some(in_data), Some((cycles, repeats))) => {
-            let out_len = out_shape.iter().product();
             assert!(out_len == input.len() * cycles * repeats);
 
-            let mut out_data: Vec<T> = Vec::with_capacity(out_len);
             let mut out_ptr = out_data.as_mut_ptr();
             for _ in 0..cycles {
                 if repeats == 1 {
@@ -63,16 +68,17 @@ pub(crate) fn expand_to<T: Copy>(input: TensorView<T>, out_shape: &[usize]) -> T
             Tensor::from_data(out_shape, out_data)
         }
-        _ => input.broadcast(out_shape).to_tensor(),
+        _ => input.broadcast(out_shape).to_tensor_buf(out_data),
     }
 }
 
-pub fn expand<T: Copy>(
+pub fn expand<T: Copy>(
+    pool: &TensorPool,
     input: TensorView<T>,
     shape: &NdTensorView<i32, 1>,
 ) -> Result<Tensor<T>, OpError> {
     let out_shape = expand_output_shape(input.shape(), shape)?;
-    Ok(expand_to(input, &out_shape))
+    Ok(expand_to(pool, input, &out_shape))
 }
 
 #[derive(Debug)]
@@ -83,14 +89,14 @@ impl Operator for Expand {
         "Expand"
     }
 
-    fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Output, OpError> {
+    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Output, OpError> {
         let input = inputs.require(0)?;
         let shape = inputs.require_as(1)?;
         let shape = static_dims!(shape, 1)?;
 
         match input {
-            Input::FloatTensor(input) => expand(input, &shape).into_op_result(),
-            Input::IntTensor(input) => expand(input, &shape).into_op_result(),
+            Input::FloatTensor(input) => expand(pool, input, &shape).into_op_result(),
+            Input::IntTensor(input) => expand(pool, input, &shape).into_op_result(),
         }
     }
 
@@ -109,9 +115,10 @@ impl Operator for Expand {
             return Ok(input);
         }
 
+        let pool = TensorPool::new();
         let output: Output = match input {
-            Output::FloatTensor(input) => expand_to(input.view(), &out_shape).into(),
-            Output::IntTensor(input) => expand_to(input.view(), &out_shape).into(),
+            Output::FloatTensor(input) => expand_to(&pool, input.view(), &out_shape).into(),
+            Output::IntTensor(input) => expand_to(&pool, input.view(), &out_shape).into(),
         };
         Ok(output)
     }
@@ -563,53 +570,57 @@ mod tests {
 
     #[test]
     fn test_expand() {
+        let pool = new_pool();
+
         // Broadcast scalar
         let input = tensor!(5.);
         let shape = ndtensor!([2, 2]);
         let expected = Tensor::from_data(&[2, 2], vec![5., 5., 5., 5.]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(&result, &expected);
 
         // Broadcast that changes dim count
         let input = Tensor::from_data(&[3, 1], (0..3).collect::<Vec<_>>());
         let shape = ndtensor!([2, 3, 1]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 3, 1]);
 
         // Broadcast that uses dimensions from both the input shape and target
         // shape in the output shape.
         let input = Tensor::from_data(&[3, 1], (0..3).collect::<Vec<_>>());
         let shape = ndtensor!([2, 1, 6]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 3, 6]);
 
         // Broadcast that does not change dim count
        let input = Tensor::from_data(&[3, 1], (0..3).collect::<Vec<_>>());
         let shape = ndtensor!([3, 4]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[3, 4]);
 
         // Broadcast of leading and trailing dims
         let input = tensor!((1, 2, 1); [1, 2]);
         let shape = ndtensor!([2, 2, 2]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 2, 2]);
         assert_eq!(result.to_vec(), &[1, 1, 2, 2, 1, 1, 2, 2]);
 
         // Broadcast of inner dim
         let input = tensor!((2, 1, 2); [1, 2, 3, 4]);
         let shape = ndtensor!([2, 2, 2]);
-        let result = expand(input.view(), &shape.view()).unwrap();
+        let result = expand(&pool, input.view(), &shape.view()).unwrap();
         assert_eq!(result.shape(), &[2, 2, 2]);
         assert_eq!(result.to_vec(), &[1, 2, 1, 2, 3, 4, 3, 4]);
     }
 
     #[test]
     fn test_expand_invalid_inputs() {
+        let pool = new_pool();
+
         // Invalid broadcast shape
         let input = tensor!([1, 2, 3]);
         let shape = ndtensor!([2, 2]);
-        let result = expand(input.view(), &shape.view());
+        let result = expand(&pool, input.view(), &shape.view());
         assert_eq!(
             result.err(),
             Some(OpError::IncompatibleInputShapes(
diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs
index 80b5b9e9..4fdf6b60 100644
--- a/src/ops/matmul.rs
+++ b/src/ops/matmul.rs
@@ -4,7 +4,7 @@ use rten_tensor::prelude::*;
 use rten_tensor::{Tensor, TensorView};
 
 use crate::check_dims;
-use crate::gemm::{gemm, GemmExecutor, GemmInputA, GemmInputB};
+use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
 use crate::ops::binary_elementwise::broadcast_shapes;
 use crate::ops::layout::expand_to;
 use crate::ops::{InputList, IntoOpResult, OpError, Operator, Output};
@@ -25,6 +25,7 @@ pub struct Gemm {
 ///
 /// nb. This is named `gemm_op` to avoid confusion with `gemm::gemm`.
 pub fn gemm_op(
+    pool: &TensorPool,
     a: TensorView,
     b: TensorView,
     c: Option<TensorView>,
@@ -40,29 +41,42 @@ pub fn gemm_op(
     let b = if transpose_b { b.transposed() } else { b };
 
     let out_shape = &[a.size(0), b.size(1)][..];
-    let mut output = match c {
+    let gemm = GemmExecutor::new();
+
+    let output = match c {
+        Some(c) if beta != 0. => {
             if !c.can_broadcast_to(out_shape) {
                 return Err(OpError::IncompatibleInputShapes(
                     "Cannot broadcast c to output shape",
                 ));
             }
-            expand_to(c, out_shape)
+            let mut output = expand_to(pool, c, out_shape);
+            let out_row_stride = output.stride(0);
+            gemm.gemm(
+                output.data_mut().unwrap(),
+                out_row_stride,
+                GemmInputA::Unpacked(a.nd_view()),
+                GemmInputB::Unpacked(b.nd_view()),
+                alpha,
+                beta,
+            );
+            output
+        }
+        _ => {
+            let mut output = pool.alloc(out_shape);
+            let out_row_stride = output.stride(0);
+            gemm.gemm_uninit(
+                output.data_mut().unwrap(),
+                out_row_stride,
+                GemmInputA::Unpacked(a.nd_view()),
+                GemmInputB::Unpacked(b.nd_view()),
+                alpha,
+            );
+            // Safety: `gemm_uninit` initialized all elements
+            unsafe { output.assume_init() }
         }
-        _ => Tensor::zeros(out_shape),
     };
 
-    let out_row_stride = output.stride(0);
-
-    gemm(
-        output.data_mut().unwrap(),
-        out_row_stride,
-        a.nd_view(),
-        b.nd_view(),
-        alpha,
-        beta,
-    );
-
     Ok(output)
 }
 
@@ -71,11 +85,12 @@ impl Operator for Gemm {
         "Gemm"
     }
 
-    fn run(&self, _pool: &TensorPool, inputs: InputList) -> Result<Output, OpError> {
+    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<Output, OpError> {
         let a = inputs.require_as(0)?;
         let b = inputs.require_as(1)?;
         let c = inputs.get_as(2)?;
         gemm_op(
+            pool,
             a,
             b,
             c,
@@ -293,6 +308,8 @@ mod tests {
 
     #[test]
     fn test_gemm_op() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[3, 10], &mut rng);
         let b = Tensor::rand(&[10, 8], &mut rng);
@@ -300,7 +317,7 @@
         let mut expected = Tensor::zeros(&[3, 8]);
         gemm_tensors(&mut expected, &a, &b, 1., 1.);
 
-        let result = gemm_op(a.view(), b.view(), None, 1.0, 1.0, false, false).unwrap();
+        let result = gemm_op(&pool, a.view(), b.view(), None, 1.0, 1.0, false, false).unwrap();
 
         expect_equal(&result, &expected)?;
 
@@ -309,6 +326,8 @@
 
     #[test]
     fn test_gemm_op_transposed() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[10, 3], &mut rng);
         let b = Tensor::rand(&[8, 10], &mut rng);
@@ -320,7 +339,7 @@
         let mut expected = Tensor::zeros(&[3, 8]);
         gemm_tensors(&mut expected, &a_transposed, &b_transposed, 1., 1.);
 
-        let result = gemm_op(a.view(), b.view(), None, 1.0, 1.0, true, true).unwrap();
+        let result = gemm_op(&pool, a.view(), b.view(), None, 1.0, 1.0, true, true).unwrap();
 
         expect_equal(&result, &expected)?;
 
@@ -329,6 +348,8 @@
 
     #[test]
     fn test_gemm_op_adds_c() -> Result<(), Box<dyn Error>> {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[3, 10], &mut rng);
         let b = Tensor::rand(&[10, 8], &mut rng);
@@ -337,7 +358,17 @@
         let mut expected = c.clone();
         gemm_tensors(&mut expected, &a, &b, 1., 1.);
 
-        let result = gemm_op(a.view(), b.view(), Some(c.view()), 1.0, 1.0, false, false).unwrap();
+        let result = gemm_op(
+            &pool,
+            a.view(),
+            b.view(),
+            Some(c.view()),
+            1.0,
+            1.0,
+            false,
+            false,
+        )
+        .unwrap();
 
         expect_equal(&result, &expected)?;
 
@@ -346,12 +377,23 @@
 
     #[test]
     fn test_gemm_op_invalid_inputs() {
+        let pool = new_pool();
+
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[3, 10], &mut rng);
         let b = Tensor::rand(&[10, 8], &mut rng);
         let c = Tensor::rand(&[3, 5], &mut rng);
 
-        let result = gemm_op(a.view(), b.view(), Some(c.view()), 1.0, 1.0, false, false);
+        let result = gemm_op(
+            &pool,
+            a.view(),
+            b.view(),
+            Some(c.view()),
+            1.0,
+            1.0,
+            false,
+            false,
+        );
         assert_eq!(
             result.err(),
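A closing note on the layout.rs side: `expand_to`'s fast path (unchanged here
apart from allocating from the pool) broadcasts a contiguous input by emitting
the data `cycles` times, repeating each element `repeats` times within a
cycle. A safe-Rust sketch of that scheme, using the hypothetical helper name
`broadcast_cycle_repeat` (the real code writes through a raw pointer into the
pool-provided buffer instead of growing a `Vec`):

    // Output is `cycles` copies of `input`, with each element repeated
    // `repeats` times within a copy, matching the assert in `expand_to`:
    // out_len == input.len() * cycles * repeats.
    fn broadcast_cycle_repeat<T: Copy>(input: &[T], cycles: usize, repeats: usize) -> Vec<T> {
        let mut out = Vec::with_capacity(input.len() * cycles * repeats);
        for _ in 0..cycles {
            if repeats == 1 {
                // Cycle-only case: bulk-copy the whole input.
                out.extend_from_slice(input);
            } else {
                for &x in input {
                    // Repeat each element `repeats` times.
                    out.extend(std::iter::repeat(x).take(repeats));
                }
            }
        }
        out
    }

    fn main() {
        // Expanding shape [3, 1] to [2, 3, 2]: 2 cycles, 2 repeats.
        assert_eq!(
            broadcast_cycle_repeat(&[0, 1, 2], 2, 2),
            [0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2]
        );
    }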